# In [None]:
import numpy as np
from functools import cache


def _text_sub_emb_pos(txt, tokenizer):
    emb_pos = []
    for i in range(1, len(txt)+1):
        et = tokenizer.encode(txt[:i])
        emb_pos.append(list(zip(et, list(range(len(et))))))
    return emb_pos

def _create_subs_emb_mask(emb_pos):
    total_len = len(emb_pos)
    emb = []
    mask = np.zeros([total_len, total_len])
    for embt, m in zip(emb_pos, mask):
        emb.append(embt[-1])
        m[len(emb)-1] = 1
        for ie in embt[:-1]:
            for j, je in enumerate(emb[:-1]):
                if je == ie:
                    m[j] = 1
                    break
    return np.array(emb), mask

def _create_emb_pos_mask(txt, tokenizer):
    """
    Atomize ``txt`` using its prefix encodings.

    :param txt: text to atomize.
    :param tokenizer: object with an ``encode(str)`` method.
    :return: ``(token_ids, positions, mask)`` -- the atom token ids, their
        positions, and the square attention mask over the atoms.
    """
    pairs, mask = _create_subs_emb_mask(_text_sub_emb_pos(txt, tokenizer))
    token_ids = pairs[:, 0]
    positions = pairs[:, 1]
    return token_ids, positions, mask

def add_masks(o_mask_size, masks):
    """
    Assemble one attention mask from a sequence of square per-token masks.

    Each mask is added on the diagonal at a running offset. Every row below a
    block additionally receives that block's last row in the block's columns,
    so all later atoms attend to what the block's final atom attends to.

    :param o_mask_size: side length of the output; callers pass the total
        number of atoms across ``masks``.
    :param masks: square 2-D arrays, placed in order.
    :return: ``(o_mask_size, o_mask_size)`` float array.
    """
    combined = np.zeros((o_mask_size, o_mask_size))
    offset = 0
    for block in masks:
        end = offset + len(block)
        combined[offset:end, offset:end] += block
        combined[end:, offset:end] = block[-1]
        offset = end
    return combined

@cache
def _process_token(t, tokenizer):
    td = tokenizer.decode(t)
    if len(td)==1 or "�" in td:
        emb = np.array([t]) 
        pos = np.zeros((1))
        mask = np.ones((1,1))
    else:
        emb, pos, mask = _create_emb_pos_mask(td, tokenizer)
    return emb, pos, mask

def create_token_atomizer(tokenizer) -> dict[int, tuple[np.ndarray, np.ndarray, np.ndarray]]:
    """
    Build the "atomizer" lookup table for every token id in the tokenizer's vocabulary.

    Each entry is computed by ``_process_token`` (itself cached), so this is the
    eager, precomputed alternative to calling ``_process_token`` on demand.

    Returned structure::

        {
            token_id: (
                embedding,                      # ids of the atom-tokens making up the token
                positional_encoding_positions,  # position of each atom-token
                attention_mask,                 # square mask over the atom-tokens
            )
        }

    The tuples are in the format ``embedding_mask`` consumes.

    :param tokenizer: tokenizer exposing ``get_vocab()`` (token string -> id),
        ``decode()`` and ``encode()``; presumably a Hugging Face tokenizer -- confirm.
    :return: the atomizer, mapping each token id to ``(embedding, positions, mask)``.
    """
    # Only the ids (dict values) are needed; there is no need to invert the vocab.
    return {
        token_id: _process_token(token_id, tokenizer)
        for token_id in tokenizer.get_vocab().values()
    }
        


def embedding_mask(txt, tokenizer, tokenizex_atomizer):  # v2
    """
    Expand ``txt`` into its atom-token sequence, positions, and attention mask.

    ``txt`` is encoded with ``tokenizer.encode``. Each resulting token id is
    replaced by its atomizer entry ``(embedding, positions, mask)``. Positions of
    each token's atoms are offset to continue right after the previous token's
    last atom position. The per-token masks are combined with ``add_masks``.

    :param txt: text to process.
    :param tokenizer: object with an ``encode(str)`` method returning token ids.
    :param tokenizex_atomizer: mapping token id -> ``(embedding, positions, mask)``,
        e.g. from ``create_token_atomizer``. A dict lacking an id raises ``KeyError``.
    :return: ``(atom_ids, positions, mask)`` -- an array of atom token ids, a list
        of their integer positions, and the square combined attention mask.
    """
    tokens = tokenizer.encode(txt)
    partial_masks = []
    res_emb = []
    res_pos = []
    pos_len = 0

    for t in tokens:
        emb, pos, mask = tokenizex_atomizer[t]
        # Force integer positions: one float-typed entry would otherwise turn the
        # running offset, and hence every later position, into floats.
        pos = np.asarray(pos, dtype=int)
        res_emb.extend(emb)
        res_pos.extend(pos + pos_len)
        # The next token's atoms start right after this token's last position.
        pos_len += pos[-1] + 1
        partial_masks.append(mask)

    full_mask = add_masks(len(res_emb), partial_masks)

    return np.array(res_emb), res_pos, full_mask