In [None]:
import os
# Work from the parent directory. Presumably this is needed so that the project modules
# imported below (`semiolog`, `temp`) and the "fr_wiki" resources resolve from the
# project root — confirm against the repository layout. Keep this before those imports.
os.chdir("../")

import semiolog as slg

# Module-level semiotic object; build_nb and the driver code below read its corpus
# (`semiotic.corpus.train`), vocab helpers (`semiotic.vocab.chain_list_alpha`) and
# configuration (`semiotic.config.vocabulary.normalizer`).
semiotic = slg.Cenematic("fr_wiki")

import regex as re
import numpy as np
from collections import Counter
from scipy.sparse import csr_matrix, lil_matrix, coo_matrix
from tqdm.notebook import tqdm, trange
from functools import partial
from functools import reduce
import operator

from pyinstrument import Profiler
import sys

import time
from pyinstrument import Profiler
import sys

from temp import findall_contexts, findall_contexts_list, find_best_pair, agglutinate_list

def build_nb(
    corpus = None,
    voc_final_length = -30,
    # save = False,
    # save_step = None,
    # progress_bar = True,
    # resume_merges = False,
    parallel = False,
    sparse = True,
    sparse_mode = "csr",
    cpu_count = 4,
    corpus_length = None,
    normalizer = None,
):
    """
    Build merges and a vocabulary by repeatedly merging the most frequent adjacent pair.

    Pair counts live in a ``voc_final_length x voc_final_length`` matrix indexed by term id
    (row = left term, column = right term). The matrix is sparse (COO between iterations,
    CSR or LIL while being updated) when ``sparse`` is True, and a dense NumPy array
    otherwise. After each merge, the counts of the pairs involving the new term are found
    with a regex search over the space-separated chain. The counts of the two merged terms
    are then corrected.

    Args:
        corpus: Corpus items, sliced with ``[:corpus_length]``. Defaults to the module-level
            ``semiotic.corpus.train``.
        voc_final_length: Final number of terms. A negative value means
            "alphabet size + abs(voc_final_length)".
        parallel: Normalize, count pairs and search merge contexts with ``cpu_count``
            processes.
        sparse: Use a sparse matrix for the pair counts instead of a dense array.
        sparse_mode: "csr" to update the sparse matrix in CSR format; any other value
            uses LIL.
        cpu_count: Number of processes (and chunks) in parallel mode.
        corpus_length: Number of corpus items to use (``None`` = all).
        normalizer: Normalizer forwarded to ``semiotic.vocab.chain_list_alpha``.

    Returns:
        ``(merges, vocabulary)``:
        - ``merges`` is the list of merged pairs as "left right" strings, in merge order.
        - ``vocabulary`` is a list of ``(term, count)`` sorted by decreasing count. The
          count is the row sum of the pair matrix (occurrences as the left member of a
          pair); terms with a zero count are dropped.
    """
    def agglutinate_chain(pair, cl_chain):
        """Replace every whitespace-delimited occurrence of "left right" by "leftright"."""
        bigram = re.escape(" ".join(pair))
        merged = "".join(pair)
        p = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
        # Use a function replacement so that backslashes in terms are not read as
        # template escapes.
        return p.sub(lambda _match: merged, cl_chain)

    def extract_drc(pairs, encoder: dict):
        """Split ``((left, right), count)`` items into COO data/row/column lists of term ids."""
        data = []
        rows = []
        columns = []
        for (r, c), d in pairs:
            data.append(d)
            rows.append(encoder[r])
            columns.append(encoder[c])
        return data, rows, columns

    def parallel_chain(chain, n_of_parts, overlap = 0):
        """
        Yield ``chain`` in about ``n_of_parts`` consecutive chunks.

        With ``overlap=1``, each chunk repeats the first element of the next one. This way
        every adjacent pair of the chain lies entirely inside exactly one chunk.
        """
        if not isinstance(chain, list):
            chain = list(chain)
        chunk_size = int(len(chain) / n_of_parts) + 1
        for i in range(0, len(chain), chunk_size):
            yield chain[i : i + chunk_size + overlap]

    def separate_chain(chain, n_of_parts, best_pair: list):
        """
        Split a chain (list of terms) into ``n_of_parts`` strings for parallel regex search.

        Cuts are moved to the right until they no longer fall inside, or right next to, an
        occurrence of ``best_pair``, so the context counts are not affected. Consecutive
        chunks share one term, and the internal cuts are marked with ``[SEP_i]``.
        """
        chunk_size = int(len(chain) / n_of_parts) + 1
        b = 0
        n = chunk_size
        chain_len = len(chain)
        for i in range(n_of_parts):
            n = (i + 1) * chunk_size
            if chain_len > n:
                while chain[n-2:n] == best_pair or chain[n-1:n+1] == best_pair:
                    n = n + 1
            yield ("[SEP_i] " if i != 0 else "") + " ".join(chain[b:n]) + (" [SEP_i]" if i != n_of_parts - 1 else "")
            b = n - 1

    def context_patterns(encoder: dict):
        """
        Return the (left, right) regex groups that match any known term as a context.

        Terms are escaped because the alphabet may contain regex metacharacters such as
        "(", "." or "?". Chain boundaries ([SEP]) and chunk boundaries ([SEP_i]) are also
        accepted as contexts.
        """
        terms = [" " + re.escape(k) + " " for k in encoder.keys()]
        left = "(" + "|".join(terms + [r"\[SEP\] ", r"\[SEP_i\] "]) + ")"
        right = "(" + "|".join(terms + [r" \[SEP\]", r" \[SEP_i\]"]) + ")"
        return left, right

    if corpus is None:
        corpus = semiotic.corpus.train

    if parallel:
        par_corpus = parallel_chain(corpus[:corpus_length], cpu_count)

        result = slg.util.multiprocessing_tqdm(partial(semiotic.vocab.chain_list_alpha, normalizer), par_corpus, cores=cpu_count, desc="Normalize & Alphabet")

        chain_list = []
        alphabet = Counter()
        for chain_l, alpha in result:
            chain_list += chain_l
            alphabet += alpha
    else:
        chain_list, alphabet = semiotic.vocab.chain_list_alpha(normalizer, corpus[:corpus_length], progress_bar=True)

    # The whole chain as one space-separated string, delimited by [SEP] so that the first
    # and last terms also have a (non-term) neighbour for the context regexes.
    cl_chain = "[SEP] " + " ".join(chain_list) + " [SEP]"
    # Alphabet symbols get ids in decreasing-frequency order; merged terms get the next ids.
    encode = {k: i for i, (k, v) in enumerate(alphabet.most_common())}
    decode = {i: k for k, i in encode.items()}
    new_i = len(encode)

    if parallel:
        # Chunks overlap by one term, so every pair is counted in exactly one chunk.
        par_chain = parallel_chain(chain_list, cpu_count, overlap=1)
        result = slg.util.multiprocessing(find_best_pair, par_chain, cores=cpu_count)
        pairs = reduce(operator.add, result).most_common()
    else:
        pairs = find_best_pair(chain_list).most_common()

    if voc_final_length < 0:
        voc_final_length = new_i + abs(voc_final_length)

    # Entry [l, r] counts the occurrences of term l immediately followed by term r.
    if sparse:
        data, rows, columns = extract_drc(pairs, encode)
        voc_matrix = coo_matrix(
            (np.array(data), (np.array(rows, dtype=int), np.array(columns, dtype=int))),
            shape=(voc_final_length, voc_final_length),
            dtype=int,
        )
    else:
        voc_matrix = np.zeros((voc_final_length, voc_final_length), dtype=int)
        for (row, column), value in pairs:
            voc_matrix[encode[row], encode[column]] = value

    merges = []
    delta_voc = voc_final_length - new_i
    best_pair = "init"
    pair_count = "---"
    separators = {"[SEP]", "[SEP_i]"}

    t = trange(delta_voc)

    for _ in t:
        t.set_description(f"Pair: {best_pair}, {pair_count}")
        t.refresh()

        # Locate the most frequent pair.
        if sparse:
            if voc_matrix.data.size == 0:
                # No pair stored at all; argmax would raise on an empty array.
                break
            max_i = voc_matrix.data.argmax()
            pair_row = voc_matrix.row[max_i]
            pair_col = voc_matrix.col[max_i]
            pair_count = voc_matrix.data[max_i]
        else:
            pair_row, pair_col = np.unravel_index(np.argmax(voc_matrix, axis=None), voc_matrix.shape)
            pair_count = voc_matrix[pair_row, pair_col]

        if pair_count <= 0:
            break

        best_pair = (decode[pair_row], decode[pair_col])
        best_pair_string = " ".join(best_pair)
        merges.append(best_pair_string)
        # Literal (escaped) form of the pair, for the regexes built in this function.
        best_pair_pattern = re.escape(best_pair_string)
        re_voc_l, re_voc_r = context_patterns(encode)
        if parallel:
            # NOTE(review): findall_contexts (from `temp`) still receives the unescaped pair
            # string, as before. If it uses it as a regex fragment, it needs the same
            # escaping as the sequential branch — confirm.
            result = slg.util.multiprocessing(
                partial(findall_contexts, best_pair_string=best_pair_string, re_voc_l=re_voc_l, re_voc_r=re_voc_r),
                separate_chain(cl_chain.split(), cpu_count, list(best_pair)),
                cores = cpu_count
                )
            merge_context = reduce(operator.add, result)
        else:
            merge_context = re.findall(re_voc_l + best_pair_pattern + re_voc_r, cl_chain, overlapped=True)

        # Count the left and right neighbours of the new term; separators are not terms.
        merge_context_count_l = Counter()
        merge_context_count_r = Counter()
        for l, r in merge_context:
            l, r = l.strip(), r.strip()
            if l not in separators:
                merge_context_count_l[encode[l]] += 1
            if r not in separators:
                merge_context_count_r[encode[r]] += 1

        if sparse:
            # Convert to CSR or LIL for item assignment and arithmetic.
            if sparse_mode == "csr":
                voc_matrix = voc_matrix.tocsr()
            else:
                voc_matrix = voc_matrix.tolil()

        for row, count in merge_context_count_l.items():
            voc_matrix[row, new_i] = count

        for column, count in merge_context_count_r.items():
            voc_matrix[new_i, column] = count

        # Correct previous counts.
        # compute #(l,r)-(l,r)
        pair_pair_count = len(re.findall(" " + best_pair_pattern + " " + best_pair_pattern + " ", cl_chain, overlapped=False))
        # remove #(l,r)-(l,r) from (l,r)-l
        voc_matrix[new_i, pair_row] -= pair_pair_count
        # remove #(l,r)-(l,r) from r-(l,r)
        voc_matrix[pair_col, new_i] -= pair_pair_count
        # remove #(l,r)-(l,r) from r-l
        voc_matrix[pair_col, pair_row] -= pair_pair_count
        # subtract (l,r)- from r-
        voc_matrix[pair_col, :new_i] -= voc_matrix[new_i, :new_i]
        # subtract -(l,r) from -l
        voc_matrix[:new_i, pair_row] -= voc_matrix[:new_i, new_i]

        # set l-r to 0
        voc_matrix[pair_row, pair_col] = 0
        # register #(l,r)-(l,r)
        voc_matrix[new_i, new_i] = pair_pair_count

        if sparse:
            # Back to COO, whose row/col/data arrays are used to find the next maximum.
            voc_matrix = voc_matrix.tocoo()

        best_pair_string_voc = "".join(best_pair)
        encode[best_pair_string_voc] = new_i
        decode[new_i] = best_pair_string_voc
        new_i += 1
        cl_chain = agglutinate_chain(best_pair, cl_chain)

    # Row sums: occurrences of each term as the left member of a pair. np.asarray/ravel
    # flattens both the sparse (n, 1) matrix and the dense 1-D result.
    freq_values = np.asarray(voc_matrix.sum(axis=1)).ravel().tolist()
    # Rows that were never assigned (early stop) sum to 0 and are dropped, so decode[i] exists.
    vocabulary = {decode[i]: v for i, v in enumerate(freq_values) if v > 0}
    vocabulary = sorted(vocabulary.items(), key=lambda x: x[1], reverse=True)

    return merges, vocabulary

# Normalizer built from the vocabulary normalizer configuration of `semiotic`.
normalizer = slg.syntagmatic.tokenizer.normalizers.Sequence(semiotic.config.vocabulary.normalizer)

# Profile the full vocabulary build. The profiler is stopped in `finally` so that it does
# not keep running when build_nb raises.
profile = Profiler()
profile.start()
try:
    merges, vocabulary = build_nb(
        semiotic.corpus.train,
        normalizer=normalizer,
        voc_final_length = -30,
        parallel=True)
finally:
    profile.stop()
print(profile.output_text(unicode=True, color=True))

print(merges)

print(vocabulary[:100])

In [None]:
# Last version of the vocabulary build method before the first used version (0.1)


    def build_old(
        self,
        corpus = None,
        vocab_size = None,
        special_tokens = None,
        save = False,
        save_step = None,
        truncate_best_size = None,
        progress_bar = True,
        resume_merges = False,
        parallel = False,
        corpus_length = None
        ):
        """
        Build the vocabulary by iteratively merging the most frequent pair of adjacent units.

        The first ``corpus_length`` items of ``self.corpus.train`` are normalized with
        ``self.normalizer.normalize``. Then the most frequent adjacent pair is merged
        ``delta_voc`` times, where ``delta_voc`` = target size - alphabet size - number of
        special tokens. Merging stops early if no pair with a positive count is left.

        Args:
            corpus: Accepted for interface compatibility but not used; the text is always
                read from ``self.corpus.train``.
            vocab_size: Target vocabulary size. A negative value means "alphabet size +
                special tokens + abs(vocab_size)". Defaults to ``self.config.size``.
            special_tokens: Tokens appended to the vocabulary with frequency 0. Defaults to
                ``self.config.special_tokens``.
            save: Save the final vocabulary with ``self.save()``.
            save_step: Used together with ``save``. An intermediate vocabulary is saved to
                ``self.path / <size>`` whenever the partial size is a multiple of
                ``save_step`` (except the final size).
            truncate_best_size: Parallel mode only. If not None, only the
                ``truncate_best_size`` most frequent pairs of each chunk are aggregated when
                choosing the next pair. Defaults to ``self.config.truncate_best_size``.
            progress_bar: Show a progress bar over the merge loop.
            resume_merges: Accepted but not used by this implementation.
            parallel: Process ``self.cpu_count`` corpus chunks with a shared-memory joblib
                pool. Pairs that straddle two chunks are not counted.
            corpus_length: Number of items of ``self.corpus.train`` to use. Defaults to
                ``self.corpus.train_len``.

        Side effects:
            Sets ``self.merges`` (only the merges actually applied), ``self.encode``,
            ``self.decode``, ``self.freq``, ``self.alpha``, ``self.len``,
            ``self.freq_mass`` and ``self.prob``.
        """

        if vocab_size is None:
            vocab_size = self.config.size

        if truncate_best_size is None:
            truncate_best_size = self.config.truncate_best_size

        if special_tokens is None:
            special_tokens = self.config.special_tokens

        if corpus_length is None:
            corpus_length = self.corpus.train_len

        # Intermediate vocabularies are only saved when both `save` and `save_step` are given.
        saveQ = bool(save) and save_step is not None
        if saveQ and not isdir(self.path):
            makedirs(self.path)

        def pre_process(corpus_chunk, normalizer):
            """Normalize a text chunk and build its pair list and pair lookup tables."""
            chain_zip = normalizer(corpus_chunk)
            # List of adjacent pairs of units.
            chain_zip = list(zip(chain_zip, chain_zip[1:]))
            # Lookup table: pair -> set of its positions in chain_zip.
            pair_pos = defaultdict(set)
            for i, k in enumerate(chain_zip):
                pair_pos[k].add(i)
            # Lookup table: pair -> frequency (size of its set of positions).
            pair_len = Counter()
            for k, pos in pair_pos.items():
                pair_len[k] = len(pos)

            return (chain_zip, pair_pos, pair_len)

        def process_best_pair(job_data, best_pair):
            """Merge every occurrence of best_pair in a job's chain and update its lookup tables."""
            chain_zip, pair_pos, pair_len = job_data
            chain_zip_len = len(chain_zip)
            merged = "".join(best_pair)

            for i in pair_pos[best_pair]:
                # Skip positions already modified earlier in this loop. This happens with
                # overlapping pairs such as "000", where ("0","0") is its own right
                # neighbour. Because sets are unordered, "000" may end up as either
                # ("00","0") or ("0","00").
                # TODO: Investigate the cost of ordering sets. With ordered sets this check
                # might only be needed for right pairs.
                if chain_zip[i] != best_pair:
                    continue

                ## Merge the best pair with the unit on its left.
                # Positions emptied (None) by previous merges are skipped.
                left_pair_i = i-1
                while left_pair_i >= 0 and chain_zip[left_pair_i] is None:
                    left_pair_i -= 1
                if left_pair_i > -1:
                    left_pair = chain_zip[left_pair_i]
                    # Don't modify best_pair's own position set while iterating over it.
                    # Its entry is deleted right after the loop, and chain_zip is updated,
                    # so such positions are caught by the check at the top of the loop.
                    if left_pair != best_pair:
                        pair_pos[left_pair].discard(left_pair_i)
                    new_pair = (left_pair[0], merged)
                    pair_pos[new_pair].add(left_pair_i)
                    pair_len[left_pair] -= 1
                    pair_len[new_pair] += 1
                    chain_zip[left_pair_i] = new_pair

                ## Merge the best pair with the unit on its right (mirror of the above).
                right_pair_i = i+1
                while right_pair_i < chain_zip_len and chain_zip[right_pair_i] is None:
                    right_pair_i += 1
                if right_pair_i < chain_zip_len:
                    right_pair = chain_zip[right_pair_i]
                    if right_pair != best_pair:
                        pair_pos[right_pair].discard(right_pair_i)
                    new_pair = (merged, right_pair[1])
                    pair_pos[new_pair].add(right_pair_i)
                    pair_len[right_pair] -= 1
                    pair_len[new_pair] += 1
                    chain_zip[right_pair_i] = new_pair

                # Empty the best pair's slot in the list of pairs.
                chain_zip[i] = None

            # Remove the best pair from the lookup tables.
            del pair_pos[best_pair]
            del pair_len[best_pair]

            return (chain_zip, pair_pos, pair_len)

        def compute_freq(chain_zip):
            """Count unit frequencies: left unit of each remaining pair, plus the chain's last unit."""
            remaining = [pair for pair in chain_zip if pair is not None]
            freq = Counter(pair[0] for pair in remaining)
            # The last unit is the right member of the last remaining pair, even when the
            # final slot of chain_zip was emptied by a merge.
            if remaining:
                freq[remaining[-1][1]] += 1
            return freq

        def select_best_pair(pair_len_global):
            """Return (pair, count) with the highest count, or None when no positive count is left."""
            if not pair_len_global:
                return None
            # With several maxima, the first one encountered is chosen. Implementations that
            # iterate counts in a different order may diverge; this mostly affects the
            # low-frequency tail of large vocabularies.
            best_pair, best_pair_len = max(pair_len_global.items(), key=operator.itemgetter(1))
            if best_pair_len <= 0:
                return None
            return best_pair, best_pair_len

        def build_alphabet(pair_len_global):
            """Count each character as the left member of pairs; add right-only characters with freq 1."""
            print("Build alphabet...")
            start = time.time()
            alphabet = Counter()
            for (l, r), v in pair_len_global.items():
                alphabet[l] += v
            # In extreme cases, a right character of a pair is never a left character.
            left_out_chars = {r for l, r in pair_len_global.keys()} - alphabet.keys()
            if len(left_out_chars) > 0:
                print(f"Adding characters: {left_out_chars}")
                for char in left_out_chars:
                    alphabet[char] += 1
            print(f"... computed in {time.time()-start} secs.\n")
            return alphabet

        def compute_target_sizes(alphabet):
            """Print the sizes and return (alpha_len, special_tokens_len, voc_final_length, delta_voc)."""
            alpha_len = len(alphabet)
            special_tokens_len = 0 if special_tokens is None else len(special_tokens)

            print(f"Alphabet Size: {alpha_len}")
            print(f"Special Tokens Size: {special_tokens_len}")

            if vocab_size < 0:
                voc_final_length = alpha_len + abs(vocab_size) + special_tokens_len
            else:
                voc_final_length = vocab_size

            delta_voc = voc_final_length - alpha_len - special_tokens_len
            print(f"Terms to compute: {delta_voc}\n")
            return alpha_len, special_tokens_len, voc_final_length, delta_voc

        def store_vocabulary(freq, merges, alphabet):
            """Set merges/encode/freq/alpha on self from unit frequencies; return the vocabulary list."""
            vocabulary = freq.most_common()
            if special_tokens is not None:
                vocabulary = vocabulary + [(token, 0) for token in special_tokens]
            self.merges = merges
            self.encode = {k: i for i, (k, v) in enumerate(vocabulary)}
            self.freq = dict(vocabulary)
            self.alpha = dict(alphabet.most_common())
            return vocabulary

        def save_intermediate(freq, merges, alphabet, voc_partial_len, start):
            """Store a snapshot of the vocabulary on self and save it under self.path / voc_partial_len."""
            store_vocabulary(freq, list(merges), alphabet)
            step_path = self.path / str(voc_partial_len)
            self.save(step_path)
            print(f"... computed in {time.time()-start} secs.")
            print(f"Intermediate vocabulary saved to {step_path}\n")

        def is_save_step(voc_partial_len, voc_final_length):
            return saveQ and voc_partial_len % save_step == 0 and voc_partial_len != voc_final_length

        if parallel:
            # Split the corpus into cpu_count contiguous chunks. The last chunk absorbs the
            # remainder, so no item is dropped when corpus_length is not a multiple of
            # cpu_count.
            chunksize = int(corpus_length/self.cpu_count)
            bounds = [i*chunksize for i in range(self.cpu_count)] + [corpus_length]
            corpus_chunks = ["".join(self.corpus.train[bounds[i]:bounds[i+1]]) for i in range(self.cpu_count)]

            with Parallel(n_jobs=self.cpu_count, require='sharedmem') as parallel_pool:
                print("Computing in parallel")
                print("Normalize and jobs data...")
                start = time.time()
                jobs_data = parallel_pool(delayed(pre_process)(chunk, self.normalizer.normalize) for chunk in corpus_chunks)
                pair_len_global = reduce(operator.add, [i[-1] for i in jobs_data])
                print(f"... computed in {time.time()-start} secs.\n")

                alphabet = build_alphabet(pair_len_global)
                alpha_len, special_tokens_len, voc_final_length, delta_voc = compute_target_sizes(alphabet)

                print("Enter loop")

                # Only merges that are actually applied are recorded.
                merges = []
                t = trange(delta_voc, disable = not progress_bar)
                for step in t:
                    best = select_best_pair(pair_len_global)
                    if best is None:
                        print("No more pairs to merge.")
                        break
                    best_pair, best_pair_len = best
                    t.set_description(f"Pair: {best_pair}, {best_pair_len}")
                    t.refresh()

                    merges.append(" ".join(best_pair))
                    jobs_data = parallel_pool(delayed(process_best_pair)(job_data, best_pair) for job_data in jobs_data)

                    if truncate_best_size is None:
                        pair_len_global = reduce(operator.add, [i[-1] for i in jobs_data])
                    else:
                        pair_len_global = reduce(operator.add, [Counter(dict(i[-1].most_common(truncate_best_size))) for i in jobs_data])

                    voc_partial_len = alpha_len + special_tokens_len + step + 1
                    if is_save_step(voc_partial_len, voc_final_length):
                        print("Saving intermediate results...")
                        start = time.time()
                        freqs = parallel_pool(delayed(compute_freq)(job_data[0]) for job_data in jobs_data)
                        save_intermediate(reduce(operator.add, freqs), merges, alphabet, voc_partial_len, start)

                print("Compute freq...")
                start = time.time()
                freqs = parallel_pool(delayed(compute_freq)(job_data[0]) for job_data in jobs_data)
                freq = reduce(operator.add, freqs)
                print(f"... computed in {time.time()-start} secs.\n")

        else:
            print("Computing sequentially")
            print("Normalize and jobs data...")
            start = time.time()
            corpus_chain = "".join(self.corpus.train[:corpus_length])
            job_data = pre_process(corpus_chain, self.normalizer.normalize)
            print(f"... computed in {time.time()-start} secs.\n")

            alphabet = build_alphabet(job_data[-1])
            alpha_len, special_tokens_len, voc_final_length, delta_voc = compute_target_sizes(alphabet)

            print("Enter loop")

            # Only merges that are actually applied are recorded.
            merges = []
            t = trange(delta_voc, disable = not progress_bar)
            for step in t:
                best = select_best_pair(job_data[-1])
                if best is None:
                    print("No more pairs to merge.")
                    break
                best_pair, best_pair_len = best
                t.set_description(f"Pair: {best_pair}, {best_pair_len}")
                t.refresh()

                merges.append(" ".join(best_pair))
                job_data = process_best_pair(job_data, best_pair)

                voc_partial_len = alpha_len + special_tokens_len + step + 1
                if is_save_step(voc_partial_len, voc_final_length):
                    print("Saving intermediate results...")
                    start = time.time()
                    save_intermediate(compute_freq(job_data[0]), merges, alphabet, voc_partial_len, start)

            print("Compute freq...")
            start = time.time()
            freq = compute_freq(job_data[0])
            print(f"... computed in {time.time()-start} secs.\n")

        vocabulary = store_vocabulary(freq, merges, alphabet)

        self.decode = {i: k for k, i in self.encode.items()}

        self.len = len(vocabulary)
        self.freq_mass = sum(self.freq.values())
        self.prob = {k: v/self.freq_mass for k, v in self.freq.items()}

        print("Vocabulary built")

        if save:
            self.save()
            print(f"Vocabulary saved to {self.path}")