## LZ78 Python Interface for Denoising with Lookahead

### Setup
Remember to do `cd crates/python/ && maturin develop && cd ../..` from the `lz78_rust` directory after making any changes to the Rust or the Python interface!

In [123]:

import numpy as np
from sys import stdout
import numpy as np
import matplotlib.pyplot as plt

import scipy.signal as signal
from scipy.linalg import toeplitz
import dill as pickle
from lz78 import Sequence, LZ78Encoder, CharacterMap, BlockLZ78Encoder, LZ78SPA
from lz78 import encoded_sequence_from_bytes, spa_from_bytes

### Small Usage Example
This is equivalent to the traversal and lookahead tests in `spa.rs`.

In [124]:
spa = LZ78SPA(2, gamma=0.5)
input = Sequence([1, 0] * 1000, alphabet_size=2)
train_loss = spa.train_on_block(input) / 2000
train_loss

0.17043380775717953

In [125]:
pdf_traverse = spa.traverse_and_get_prob([0, 1])
pdf_traverse

[0.9883720930232558, 0.011627906976744186]

In [126]:
pdf_lookahead = spa.traverse_and_get_prob_with_lookahead(input=[1, 0, 1], lookahead=[1, 0])
pdf_lookahead

[1.0, 0.0]

In [127]:
def universal_denoiser_naive(signal_to_filter, spa_tree, char_map, window_len,pi_matrix, x_alphabet):
    sequence_l = len(signal_to_filter)
    
    Zt_index = signal_to_filter[window_len:sequence_l]
    print(Zt_index[:10  ])
    filtered_signal = []
    for il in range(sequence_l-window_len):
        prob = spa_tree.traverse_and_get_prob(signal_to_filter[il:il+window_len])
        PX_t_Z_tm1_ele=np.matmul(np.linalg.inv(np.transpose(pi_matrix)),prob)
        alphabet_ind = Zt_index[il]
        num = pi_matrix[:,alphabet_ind] * PX_t_Z_tm1_ele
        den = prob[alphabet_ind]
        pxt = num/den
        pxt = pxt/np.sum(pxt)
        xt_hat = np.dot(pxt,x_alphabet)
        filtered_signal.append(xt_hat)
    
    return filtered_signal

def get_prob_from_MC(Dd, seed_str, spa_tree, char_map,num_exper):
    generated_seq = []
    for i_mc in range(num_exper):
        output, loss = spa_tree.generate_data(Dd, seed_data = seed_str,temperature =1, min_context= 3)
        generated_seq.append(output)
    # count the number of times of last element in the generated sequence
    prob = np.zeros(3)
    for i in range(num_exper):
        # charmap the generated sequence
        generated_ind = char_map.str_to_symbol_list(generated_seq[i][-1])
        #print(generated_ind)    
        prob[generated_ind] += 1
    prob = prob/num_exper
    return prob

def universal_denoiser_delay_mc(Delay,signal_to_filter_mc, spa_tree, char_map, 
                                window_len,pi_matrix, x_alphabet, num_exp):
    sequence_len = len(signal_to_filter_mc) 
    Zt_index = char_map.str_to_symbol_list(signal_to_filter_mc[window_len:sequence_len])
    filtered_signal_delay_mc = []
    for idmc in range(sequence_len-window_len):
        Zt_seed_str = signal_to_filter_mc[idmc:window_len+idmc-Delay]
        prob_mc = get_prob_from_MC(Delay, Zt_seed_str, spa_tree, num_exp)
        
        PX_t_Z_tm1_mc = np.matmul(np.linalg.inv(np.transpose(pi_matrix)),prob_mc)
        alphabet_ind = Zt_index[idmc]
        #print(alphabet_ind)
        num = pi_matrix[:,alphabet_ind] * PX_t_Z_tm1_mc
        den = prob_mc[alphabet_ind]
        pxt = num/den
        pxt = pxt/np.sum(pxt)
        xt_hat_mc = np.dot(pxt,x_alphabet)
        #print(xt_hat_mc)
        filtered_signal_delay_mc.append(xt_hat_mc)
    
    #print(filtered_signal_delay_mc)

    return filtered_signal_delay_mc

In [128]:
def universal_denoiser_lookahead(Lookahead,signal_to_filter_la, spa_tree, char_map: CharacterMap, 
                                window_len,pi_matrix, x_alphabet):

    sequence_len = len(signal_to_filter_la) 
    Zt_index =signal_to_filter_la[window_len:sequence_len]
    
    filtered_signal_lookahead = []
    #signal_to_filter_la = np.array(signal_to_filter_la)
    for ila in range(sequence_len-window_len):
        Zt_seed_str = signal_to_filter_la[ila:window_len+ila]
        Zt_lookahead_str = signal_to_filter_la[ila+window_len:ila+window_len+Lookahead]

        Zt_seed_vec = Zt_seed_str
        Zt_lookahead_vec = Zt_lookahead_str
        prob_la = spa_tree.traverse_and_get_prob_with_lookahead(Zt_seed_vec, Zt_lookahead_vec)
        
        PX_t_Z_tm1_la = np.matmul(np.linalg.inv(np.transpose(pi_matrix)),prob_la)
        alphabet_ind = Zt_index[ila]
        num = pi_matrix[:,alphabet_ind] * PX_t_Z_tm1_la
        den = prob_la[alphabet_ind]
        pxt = num/den
        pxt = pxt/np.sum(pxt)
        xt_hat_la = np.dot(pxt,x_alphabet)
        filtered_signal_lookahead.append(xt_hat_la)
    
    return filtered_signal_lookahead    

## Markov Example 

In [129]:
def generate_markov_sequence_random_initial(N, p):
    """
    Generates a Markov sequence of length N with two states: -1 and 1.
    The initial state is randomly chosen with a 50:50 chance.
    The probability of switching states is p, and the probability of staying in the same state is 1 - p.
    
    Parameters:
    N (int): Length of the Markov sequence.
    p (float): Probability of switching states.
    
    Returns:
    np.ndarray: The generated Markov sequence.
    """
    # Initialize the first state randomly with a 50:50 chance
    initial_state = np.random.choice([1, -1])
    
    # Initialize the sequence
    sequence = np.zeros(N)
    sequence[0] = initial_state

    # Generate the Markov sequence
    for i in range(1, N):
        if np.random.rand() < p:
            sequence[i] = -sequence[i-1]  # Switch state
        else:
            sequence[i] = sequence[i-1]  # Stay in the same state

    return sequence


In [130]:
## this initializes the sequence that trains the SPA
#N = 1_000_000_000   #--> between 3^18 and 3^19  
N =  1_000_000_000 
D = 2
p = 0.8 # probability of changing state
Xt= generate_markov_sequence_random_initial(N, p)
Nt = np.random.choice([-1, 1], size=N, p=[0.5, 0.5])
Zt=  Xt + Nt + 2
# convert Zt to integers
Zt = Zt.astype(int)
Zt_str = ''.join(str(num) for num in Zt)

In [None]:
stdout.flush() ## add for any long process
spa = LZ78SPA(3, gamma=0.5)
charmap_str = "000022224444"  ## IMPORTANT: do specify the sequence of the alphabet you want to train on. I'll make this more robust in the next update.
char_map =  CharacterMap(charmap_str)

#train on shifted input
for shift in range(5):
    input = Sequence(Zt_str[shift:], charmap =char_map)
    train_loss = spa.train_on_block(input) / N
train_loss

1.4459596367035228

In [132]:
stdout.flush()
bytes = spa.to_bytes()

with open("test_data/trained_markov_spa.bin", 'wb') as file:
    file.write(bytes)

In [52]:
# load the trained spa
with open("trained_markov_spa.bin", 'rb') as file:
    spa = spa_from_bytes(file.read())



In [133]:
pi_matrix_true = np.array([[0.5,0.5,0],[1/2,0,1/2],[0, 0.5,0.5]])
print(pi_matrix_true)


[[0.5 0.5 0. ]
 [0.5 0.  0.5]
 [0.  0.5 0.5]]


In [15]:
len_alphabet = input.alphabet_size() # 3
pi_matrix_approx = np.zeros((len_alphabet, len_alphabet))
# approximate the transition matrix from Xt to Zt
alphabet_Xt = [-1,0, 1]

def compute_pi_matrix(Xt,Zt_str,alphabet, alphabet_Xt):
    pi_matrix = np.ones((len(alphabet_Xt), len(alphabet)))
    for i in range(len(alphabet)):
        for j in range(len(alphabet)):
            # count the number of times alphabet[i] is followed by alphabet[j]
            count = 1
            for k in range(len(Xt)-1):
                if Xt[k] == alphabet_Xt[i] and Zt_str[k] == alphabet[j]:
                    count += 1
            pi_matrix[i,j] = count
    
    pi_matrix = pi_matrix / np.sum(pi_matrix, axis=1)[:,np.newaxis]
    return pi_matrix
pi_matrix_approx = compute_pi_matrix(Xt,Zt_str,['-2','0','2'], alphabet_Xt)
pi_matrix_approx


array([[0.33333333, 0.33333333, 0.33333333],
       [0.33333333, 0.33333333, 0.33333333],
       [0.33333333, 0.33333333, 0.33333333]])

In [207]:
stdout.flush()
LAH =1# lookahead
WL = 5 # window length
start = 11000
stop_seq = 12000
alphabet_Xt = [-1,0, 1]
print(Zt_str[start:stop_seq])
Zt_to_filter = char_map.encode(Zt_str[start:stop_seq])
print(Zt_to_filter)
#Zt_encoded = char_map.encode("".join(Zt_to_filter))
#Xt_hat = universal_denoiser_naive(Zt_to_filter, spa, char_map,WL, pi_matrix_true, alphabet_Xt)
Xt_hat = universal_denoiser_lookahead(LAH,Zt_to_filter, spa, char_map,WL, pi_matrix_true, alphabet_Xt)

2220220202224202020402222402042020442240022242240042204242242004002022224020240204224222004004220402240402424422240204220422202024022404240404220222204020224022222240220422242442424200202202404422220204422024224022224042422040222224022402222022224240204240422240224240222424024204220402024224222422202242442040222224202220424022224424022204020224202222200422042402402424222002022404202202204220420024220222042240222242422220222240242020420442440202024224240400404240424220220422020220222040202202020204022022224222404224220222220222020422220222200420422042422202240224242422020402220240224022224022404220404222424004222420440242004404022222200002040220420022020402240222202424202242424222240222242424240220422404420042404224240424222240402440222402240222002422240224222240220202220204022422402202220420240200222422004022240404242242222224222222224422224220204220220440424040224222024022042040240402204004242022424242404222202224220424242200404242420422242202242040404202404004224224200202040420222242

In [208]:
X_clean = Xt[start+WL:stop_seq]
# calculate the mean squared error on not-nan values
MSE_loss = 0 
k = 0
for i in range(len(Xt_hat)):
    if not np.isnan(Xt_hat[i]):
        loss_ind = (Xt_hat[i] - X_clean[i])**2
        if loss_ind <4:
            MSE_loss += loss_ind
            k += 1
        else:
            print(Xt_hat[i], X_clean[i])
#MSE_loss = MSE_loss/X_clean.shape[0]
MSE_loss = MSE_loss/k
MSE_loss

0.41046149000949395

In [166]:
print(Xt_hat)

[-1.0028757880765402, 0.9988846132396411, -1.0053230031770806, 0.3068248801865758, -0.9895710928319624, 0.4358658621680418, -0.0024763698510690824, 1.0093524049041178, -0.777767822477078, -0.9371284185493458, 0.9919210053859961, -0.9973632665360855, 0.7824819781412293, 1.0152662432828525, -1.0310421286031044, 0.9946106381316879, -0.7195978929073374, 0.6372583326043216, 1.0219711236660385, -0.7242989047266974, 0.640334058797112, 0.9729638701775872, -0.9871811306242789, 0.5130308116153139, -0.15328889871645024, -0.1553875458367247, -1.0128057806170148, 0.6017389410813163, -0.3615734679882801, 0.21192143560564608, 0.0006612817585427422, 0.0006612817585427422, 0.07854639379782236, -0.1312643570390512, 0.21612840466926064, -0.35485847465185816, -1.0128057806170148, 0.5464994435825161, -0.24353331631350533, -0.0010511118990932933, 0.3606855894202536, 1.0031583380123525, -1.001634375371449, 1.013627465743111, -1.0069665752612467, 0.9979972904517876, -1.0061811534946368, 1.0057759247217228, -0

In [102]:
X_clean

array([-1.,  1., -1., -1.,  1., -1.,  1., -1.])

In [191]:
stdout.flush()
LAH =2 # lookahead
WL = 13 # window length
alphabet_Xt = [-1,0, 1]
#print(Zt_str[11000:12000])
Zt_to_filter = char_map.encode(Zt_str[110000:120000])
#print(Zt_to_filter)
#Zt_encoded = char_map.encode("".join(Zt_to_filter))
Xt_hat = universal_denoiser_naive(Zt_to_filter, spa, char_map,WL, pi_matrix_true, alphabet_Xt)
#Xt_hat = universal_denoiser_lookahead(LAH,Zt_to_filter, spa, char_map,WL, pi_matrix_true, alphabet_Xt)

[2, 0, 2, 2, 1, 1, 1, 1, 1, 2]


In [192]:
X_clean = Xt[110000+WL:120000]
# calculate the mean squared error on not-nan values
MSE_loss = 0 
k = 0
for i in range(len(Xt_hat)):
    if not np.isnan(Xt_hat[i]):
        loss_ind = (Xt_hat[i] - X_clean[i])**2

        MSE_loss += loss_ind
        
MSE_loss = MSE_loss/X_clean.shape[0]

MSE_loss

0.4207241235836166