In [1]:
from Data import get_min_descriptionsNorm_triples_relations, LOGGER_FILES, PKLS_FILES
from utils.utils import read_cached_array, cache_array, get_logger
from transformers import BertModel, BertTokenizer, BertTokenizerFast

import torch

from torch.utils.data import Dataset, DataLoader


from math import ceil
from tqdm import tqdm

from collections import defaultdict




  from .autonotebook import tqdm as notebook_tqdm


In [None]:


def batch_tokenize(sentences, tokenizer, batch_size=2056):
    all_batches = []
    all_masks = []
    sentence_tokens = []
    
    total_batches = ceil(len(sentences) / batch_size)
    for i in tqdm(
        range(0, len(sentences), batch_size),
        total=total_batches,
        desc="Tokenizing batches",
        unit="batch"
    ):
        batch = sentences[i : i + batch_size]
        enc = tokenizer(
            batch, 
            return_offsets_mapping=True, 
            add_special_tokens = False 
        )
        
        
        sentence_tokens.extend([tokenizer.convert_ids_to_tokens(sen) for sen in enc["input_ids"]])
    return torch.cat(all_batches, dim=0), torch.cat(all_masks, dim=0) , sentence_tokens



def tokenize_descriptions(sentences_texts):
    
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    descs_tokenized_input_ids, descs_tokenized_attention_masks, sentence_tokens = batch_tokenize(sentences_texts, tokenizer, batch_size=10)
    return descs_tokenized_input_ids, descs_tokenized_attention_masks, sentence_tokens


In [4]:
k = 10 
CHUNK_SIZE = 2
DESCRIPTION_MAX_LENGTH = 128

descs, triples, relations, aliases = get_min_descriptionsNorm_triples_relations(k)
sentences_ids = list(descs.keys())
sentences_texts = list(descs.values())

# descs_tokenized_input_ids, descs_tokenized_attention_masks, sentence_tokens = tokenize_descriptions(sentences_texts)





In [72]:
aliases = {
    "q100": ["football", "football soccer"],
    "q500": ["number 9", "num 0"],
    "q700": ["Roland crystal"],
    "q800": ["ML", "machine learning"],
    "q900": ["italy", "italia"],
    "q901": ["europe", "european union"],
    "q1000": ["soccer football player"]
}
triples= {
    "q2": [ ("q700", "r2", "q800")],
    "q1": [ ("q100" , "r1", "q500"), ("q100", "r2", "q1000")],
    "q3": [("q900", "r3", "q901")]
}

descriptions = {
    "q1": "Raymond Neifel is an indian football soccer player with number 9",
    "q2": "Roland Crystal is the greatest machine learning engineer",
    "q3": "Italia is a country in the european union"
}


In [81]:
import re 
sentences_triples_heads_aliases = [
    [aliases[t[0]] for t in triples[s]] 
    for s in sentences_ids
]

sentences_triples_tails_aliases = [
    [aliases[t[2]] for t in triples[s]] 
    for s in sentences_ids
]



print(f"sentences_triples_heads_aliases : {sentences_triples_heads_aliases}")
print(f"sentences_triples_tails_aliases : {sentences_triples_tails_aliases}")

alias_pattern_map


sentences_triples_heads_aliases : [[['football', 'football soccer'], ['football', 'football soccer']], [['Roland crystal']], [['italy', 'italia']]]
sentences_triples_tails_aliases : [[['number 9', 'num 0'], ['soccer football player']], [['ML', 'machine learning']], [['europe', 'european union']]]


{'football': re.compile(r'football', re.IGNORECASE|re.UNICODE),
 'football soccer': re.compile(r'football\s*soccer', re.IGNORECASE|re.UNICODE),
 'number 9': re.compile(r'number\s*9', re.IGNORECASE|re.UNICODE),
 'num 0': re.compile(r'num\s*0', re.IGNORECASE|re.UNICODE),
 'Roland crystal': re.compile(r'Roland\s*crystal', re.IGNORECASE|re.UNICODE),
 'ML': re.compile(r'ML', re.IGNORECASE|re.UNICODE),
 'machine learning': re.compile(r'machine\s*learning',
            re.IGNORECASE|re.UNICODE),
 'italy': re.compile(r'italy', re.IGNORECASE|re.UNICODE),
 'italia': re.compile(r'italia', re.IGNORECASE|re.UNICODE),
 'europe': re.compile(r'europe', re.IGNORECASE|re.UNICODE),
 'european union': re.compile(r'european\s*union', re.IGNORECASE|re.UNICODE),
 'soccer football player': re.compile(r'soccer\s*football\s*player',
            re.IGNORECASE|re.UNICODE)}

In [85]:
import re 

CHUNK_SIZE = 2
BATCH_SIZE = len(descriptions)
L = 128


tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')


sentences_ids = list(descriptions.keys())
sentences_texts = list(descriptions.values())


sentences_triples_heads_aliases = [
    [aliases[t[0]] for t in triples[s]] 
    for s in sentences_ids
]
sentences_triples_tails_aliases = [
    [aliases[t[2]] for t in triples[s]] 
    for s in sentences_ids
]



alias_pattern_map = {} 
for lst in aliases.values():
    for alias in lst:
        escaped = re.escape(alias)
        flexible  = escaped.replace(r"\ ", r"\s*")
        pattern   = rf"\b{flexible}\b"
        alias_pattern_map[alias] = re.compile(pattern, re.IGNORECASE)




silver_span_head_s = torch.zeros(BATCH_SIZE, L )
silver_span_head_e = torch.zeros(BATCH_SIZE, L )
silver_span_tail_s = torch.zeros(BATCH_SIZE, L )
silver_span_tail_e = torch.zeros(BATCH_SIZE, L )

all_sentences_tokens = []
all_sentences_offsets = []


total_batches = ceil(len(sentences_texts) / CHUNK_SIZE)
for i in tqdm(
    range(0, len(sentences_texts), CHUNK_SIZE),
    total=total_batches,
    desc="Tokenizing batches",
    unit="batch"
):

    batch = sentences_texts[i : i + CHUNK_SIZE]
    enc = tokenizer(
        batch, 
        return_offsets_mapping=True, 
        add_special_tokens = False 
    )
    all_sentences_offsets.extend(enc.offset_mapping)
    
    for sen_idx, enc_obj in enumerate(enc.encodings):
        all_sentences_tokens.append(enc_obj.tokens)
        
        sentence_idx_in_batch = i + sen_idx
        current_description = sentences_texts[sentence_idx_in_batch]
        sentence_heads_aliases = sentences_triples_heads_aliases[sentence_idx_in_batch]
        sentence_tails_aliases = sentences_triples_tails_aliases[sentence_idx_in_batch]
        sentence_tokens_offset = all_sentences_offsets[sentence_idx_in_batch]
        
        for one_als_list in sentence_heads_aliases:
            for als_str in one_als_list:
                pattern = alias_pattern_map[als_str]  
                m = pattern.search(current_description)
                if not m: continue 
                start_char, end_char = m.span()
                token_indices = [
                    i for i, (s, e) in enumerate(sentence_tokens_offset)
                    if (s < end_char) and (e > start_char)
                ]
                head_start, head_end = token_indices[0], token_indices[-1]
                silver_span_head_s[sentence_idx_in_batch, head_start] = 1
                silver_span_head_e[sentence_idx_in_batch, head_end] = 1
                break
        
        for one_als_list in sentence_tails_aliases:
            for als_str in one_als_list:
                pattern = alias_pattern_map[als_str]  
                m = pattern.search(current_description)
                if not m: continue 
                start_char, end_char = m.span()
                token_indices = [
                    i for i, (s, e) in enumerate(sentence_tokens_offset)
                    if (s < end_char) and (e > start_char)
                ]
                tail_start, tail_end = token_indices[0], token_indices[-1]
                silver_span_tail_s[sentence_idx_in_batch, tail_start] = 1
                silver_span_tail_e[sentence_idx_in_batch, tail_end] = 1
                break
        


Tokenizing batches: 100%|██████████| 2/2 [00:00<00:00, 1328.78batch/s]


In [83]:
for sen_idx in range(BATCH_SIZE):
    head_starts = silver_span_head_s[sen_idx]
    head_ends = silver_span_head_e[sen_idx]
    sen_tokens = all_sentences_tokens[sen_idx]
    
    head_starts_idxs = torch.nonzero(head_starts == 1, as_tuple=True)[0]
    head_ends_idxs = torch.nonzero(head_ends == 1, as_tuple=True)[0]
    
    print(f"sentence: " , sentences_texts[sen_idx])
    for h_s, h_e in zip(head_starts_idxs, head_ends_idxs):
        print(f"\t HEAD: {sen_tokens[h_s: h_e + 1]}")
        

for sen_idx in range(BATCH_SIZE):
    tail_starts = silver_span_tail_s[sen_idx]
    tail_ends = silver_span_tail_e[sen_idx]
    sen_tokens = all_sentences_tokens[sen_idx]
    
    tail_starts_idxs = torch.nonzero(tail_starts == 1, as_tuple=True)[0]
    tail_ends_idxs = torch.nonzero(tail_ends == 1, as_tuple=True)[0]
    
    print(f"sentence: " , sentences_texts[sen_idx])
    for t_s, t_e in zip(tail_starts_idxs, tail_ends_idxs):
        print(f"\t TAIL: {sen_tokens[t_s: t_e + 1]}")
        


sentence:  Raymond Neifel is an indian football soccer player with number 9
	 HEAD: ['football']
sentence:  Roland Crystal is the greatest machine learning engineer
	 HEAD: ['Roland', 'Crystal']
sentence:  Italia is a country in the european union
	 HEAD: ['Italia']
sentence:  Raymond Neifel is an indian football soccer player with number 9
	 TAIL: ['number', '9']
sentence:  Roland Crystal is the greatest machine learning engineer
	 TAIL: ['machine', 'learning']
sentence:  Italia is a country in the european union
	 TAIL: ['euro', '##pe']


## CHECK MY SILVER SPANS THAT I CREATED

In [4]:
silver_spans = read_cached_array(PKLS_FILES["silver_spans"][1000])

silver_span_head_s = silver_spans["head_start"]
silver_span_head_e = silver_spans["head_end"]
silver_span_tail_s = silver_spans["tail_start"]
silver_span_tail_e = silver_spans["tail_end"]

In [17]:
print(f"all count: {silver_span_head_s.shape[0]}")
has_triples = 0
for idx, (h_s, h_e, t_s, t_e) in enumerate(zip(silver_span_head_s, silver_span_head_e, silver_span_tail_s, silver_span_tail_e)):
    idxs_1 = torch.nonzero(h_s == 1, as_tuple=True)[0]
    idxs_2 = torch.nonzero(h_e == 1, as_tuple=True)[0]
    idxs_3 = torch.nonzero(t_s == 1, as_tuple=True)[0]
    idxs_4 = torch.nonzero(t_s == 1, as_tuple=True)[0]
    if len(idxs_1) == 0 or len(idxs_2) == 0 or len(idxs_3) == 0 or len(idxs_4) == 0: continue
    has_triples += 1
    
print(f"has triples count: {has_triples}")

all count: 3364
has triples count: 575


In [37]:
idx = 4
k = 1000
descs, triples, _, aliases = get_min_descriptionsNorm_triples_relations(k)
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')

sentences_ids = list(descs.keys())
sentences_texts = list(descs.values())

my_sentence_id = sentences_ids[idx]
my_sentence_text = sentences_texts[idx]
my_sentence_triples = triples[my_sentence_id  ]
my_sentence_triples_str = [
    {"head": aliases[trp[0]], "tail": aliases[trp[2]]}
    for trp in my_sentence_triples
]
enc = tokenizer(
            my_sentence_text, 
            return_offsets_mapping=True,
            add_special_tokens = False,
            padding="max_length", 
            truncation=True,
            max_length=128
            
        )


my_sentence_tokens = enc[0].tokens




print(f"my sentence text: {my_sentence_text}" )
print(f"my_sentence_tokens: {my_sentence_tokens}" )
print("sentence triples: ")
for tt in my_sentence_triples_str:
    print(f"\t NEW TRIPLE")
    print(f"\t head: {tt['head']}")
    print(f"\t tail: {tt['tail']}")

print("silver spans")
head_start_idxs = torch.nonzero(silver_span_head_s[idx] == 1, as_tuple=True)[0]
head_end_idxs = torch.nonzero(silver_span_head_e[idx] == 1, as_tuple=True)[0]
tail_start_idxs = torch.nonzero(silver_span_tail_s[idx] == 1, as_tuple=True)[0]
tail_end_idxs = torch.nonzero(silver_span_tail_e[idx] == 1, as_tuple=True)[0]
print(tail_start_idxs)
print(tail_end_idxs)
used_ends = set()
for h_s in head_start_idxs:
    for h_e in head_end_idxs:
        if h_e.item() not in used_ends:
            used_ends.add(h_e.item())
            print(f"HEAD:  {my_sentence_tokens[h_s: h_e + 1 ]} ")
            
            break


used_ends = set()
for h_s in tail_start_idxs:
    for h_e in tail_end_idxs:
        if h_e.item() not in used_ends:
            used_ends.add(h_e.item())
            found_tail = True 
            print(f"TAIL:  {my_sentence_tokens[h_s: h_e + 1 ]} ")
            
            break        



my sentence text: Good Deal with Dave Lieberman is a television cooking show hosted by Dave Lieberman that airs on the Food Network in the United States and Food Network Canada in Canada. The show premiered on Food Network on April 16, 2005. Lieberman's show presents affordable gourmet quality recipes.
my_sentence_tokens: ['Good', 'Deal', 'with', 'Dave', 'Lie', '##berman', 'is', 'a', 'television', 'cooking', 'show', 'hosted', 'by', 'Dave', 'Lie', '##berman', 'that', 'airs', 'on', 'the', 'Food', 'Network', 'in', 'the', 'United', 'States', 'and', 'Food', 'Network', 'Canada', 'in', 'Canada', '.', 'The', 'show', 'premiered', 'on', 'Food', 'Network', 'on', 'April', '16', ',', '2005', '.', 'Lie', '##berman', "'", 's', 'show', 'presents', 'affordable', 'go', '##ur', '##met', 'quality', 'recipes', '.', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD

## SIMULATE MY MODEL AND COMPARE 

In [52]:
import torch 

idx = 4
silver_spans = read_cached_array(PKLS_FILES["silver_spans"][1000])

silver_span_head_s = silver_spans["head_start"]
silver_span_head_e = silver_spans["head_end"]
silver_span_tail_s = silver_spans["tail_start"]
silver_span_tail_e = silver_spans["tail_end"]


head_start_idxs = torch.nonzero(silver_span_head_s[idx] == 1, as_tuple=True)[0]
head_end_idxs = torch.nonzero(silver_span_head_e[idx] == 1, as_tuple=True)[0]
tail_start_idxs = torch.nonzero(silver_span_tail_s[idx] == 1, as_tuple=True)[0]
tail_end_idxs = torch.nonzero(silver_span_tail_e[idx] == 1, as_tuple=True)[0]

tail_end_idxs


tensor([21, 25])

In [81]:
import torch 
from utils.model_helpers import extract_triples

sentences = [
    "Good Deal with Dave Lieberman is a television cooking show hosted by Dave Lieberman that airs on the Food Network in the United States and Food Network Canada in Canada. The show premiered on Food Network on April 16, 2005. Lieberman's show presents affordable gourmet quality recipes."
]
L = 128 
B = 1
R = 1
#B, L
subj_idxs = [
    [(0,5)]
]

subj_start_probs = torch.zeros((B ,L)) 
subj_start_probs[0, 0 ] = 0.9

subj_end_probs = torch.zeros((B ,L))
subj_end_probs[0, 5] = 0.9

f_obj_start_probs = torch.zeros((B,L))
f_obj_end_probs = torch.zeros((B,L))
f_obj_start_probs[0, 19 ] = 0.9
f_obj_start_probs[0, 24 ] = 0.9
f_obj_start_probs[0, 30 ] = 0.9

f_obj_end_probs[0, 21 ] = 0.9
f_obj_end_probs[0, 25 ] = 0.9
f_obj_end_probs[0, 30 ] = 0.9

torch.nonzero(f_obj_start_probs[0] > 0, as_tuple=True)[0]


extracted_triples = extract_triples(subj_idxs, f_obj_start_probs.unsqueeze(-1), f_obj_end_probs.unsqueeze(-1),  True , threshold=.5)
# B , NUM_TRIPLES
len(extracted_triples[0])
extracted_triples

[[((0, 5), 0, (19, 21)), ((0, 5), 0, (24, 25)), ((0, 5), 0, (30, 30))]]

In [105]:

def clean_descriptions_dict(descriptions_dict, silver_spans):
    all_keys = list(descriptions_dict.keys())
    silver_span_head_s = silver_spans["head_start"]
    silver_span_head_e = silver_spans["head_end"]
    silver_span_tail_s = silver_spans["tail_start"]
    silver_span_tail_e = silver_spans["tail_end"]
    
    m1 = silver_span_head_s .any(dim=1) 
    m2 = silver_span_head_e .any(dim=1)
    m3 = silver_span_tail_s .any(dim=1) 
    m4 = silver_span_tail_e .any(dim=1)
    mask = m1 & m2 & m3 & m4
    indexes    = torch.nonzero(mask, as_tuple=True)[0]
    cleaned_desc_dict= {k: descriptions_dict[k]  for idx, k in enumerate(all_keys) if idx in indexes  }
    return cleaned_desc_dict, silver_span_head_s[indexes],silver_span_head_e[indexes],silver_span_tail_s[indexes], silver_span_tail_e[indexes]
            


In [111]:
from utils.model_helpers import get_h_gs, extract_first_embeddings,extract_last_idxs, extract_triples,  merge_triples


class BRASKDataSet(Dataset):
    def __init__(self, descriptions_dict, silver_spans , desc_max_length=4):
        print("Initiating dataset.. ")
        cleaned_descriptions, silver_span_head_s, silver_span_head_e, silver_span_tail_s, silver_span_tail_e = descriptions_dict, silver_spans["head_start"], silver_spans["head_end"], silver_spans["tail_start"], silver_spans["tail_end"]
        
        #silver_spans should be dictionary having keys head_start, head_end, tail_start, tail_end, each one is tensor with shape (B, seq_len) 
        valid = (len(cleaned_descriptions), desc_max_length) == silver_span_head_s.shape == silver_span_tail_s.shape == silver_span_tail_e.shape==silver_span_head_e.shape 
        assert valid 
        if valid:
            print("\tvalid")
            tokenizer =BertTokenizer.from_pretrained('bert-base-cased')
            model = BertModel.from_pretrained('bert-base-cased')
            print("\tcreating clean descriptions")
            
            print(f"\twe have {len(cleaned_descriptions)} descriptions")
            
            sentences = list(cleaned_descriptions.values())
            print("\tcreating h_gs")
            self.h_gs, self.embs = get_h_gs(sentences, tokenizer, model, max_length=desc_max_length  )  #  h_gs (batch_size, hidden_size), embs (batch_size, seq_len, hidden_size)
            print(f"\tself embs shape: {self.embs.shape} should be ({len(cleaned_descriptions)}, {desc_max_length},hidden_size )")
            assert self.embs.shape[0] == len(cleaned_descriptions)
            assert self.embs.shape[1] == desc_max_length
            self.labels_head_start, self.labels_head_end, self.labels_tail_start, self.labels_tail_end =  silver_span_head_s, silver_span_head_e, silver_span_tail_s, silver_span_tail_e 

    def __getitem__(self,idx):
        return  {
            "h_gs": self.h_gs[idx], 
            "embs": self.embs[idx],
            "labels_head_start": self.labels_head_start[idx] ,
            "labels_head_end": self.labels_head_end[idx],
            "labels_tail_start": self.labels_tail_start[idx],
            "labels_tail_end": self.labels_tail_end[idx],

        }


    def __len__(self):
        return self.h_gs.shape[0]
    
    def save(self, path):
        di = {
            "h_gs": self.h_gs.cpu(),
            "embs": self.embs.cpu(),
            "labels_head_start": self.labels_head_start.cpu() ,
            "labels_head_end": self.labels_head_end.cpu(),
            "labels_tail_start":  self.labels_tail_start.cpu(),
            "labels_tail_end":  self.labels_tail_end.cpu(),
            
        }
        cache_array(di, path)
    @classmethod
    def load(cls, path):
        print("loadding dataset from cache.. ")
        data = read_cached_array(path)

        dataset = cls.__new__(cls)
        dataset.h_gs = data["h_gs"]
        print(f"\t dataset.h_gs.shape: {dataset.h_gs.shape}")
        dataset.embs = data["embs"]
        print(f"\t dataset.embs.shape: {dataset.embs.shape}")
        dataset.labels_head_start = data["labels_head_start"]
        dataset.labels_head_end = data["labels_head_end"]
        dataset.labels_tail_start = data["labels_tail_start"]
        dataset.labels_tail_end = data["labels_tail_end"]
        
        return dataset
    


In [112]:
descriptions_dict = {
    "q1": "I am q1 in the field",
    "q2": "I am q2 fsd dsaonv noifd", 
    "q3": "I am q3 hello should be deleted",
    "q4": "I am q4 here but not exists",
}

silver_spans = {
    "head_start" : torch.tensor(
        [
            [1,1,1,1], 
            [0,0,0,0], 
            [1,0,0,0], 
            [0,0,0,1],
        ]
    ),
    "head_end" : torch.tensor(
        [
            [1,1,1,1], 
            [0,0,0,0], 
            [1,0,0,0], 
            [0,0,0,1],
        ]
    ),
      "tail_start" : torch.tensor(
        [
            [1,1,1,1], 
            [0,0,0,0], 
            [1,0,0,0], 
            [0,0,0,1],
        ]
    ),
    "tail_end" : torch.tensor(
        [
            [1,1,1,1], 
            [0,0,0,0], 
            [1,0,0,0], 
            [0,0,0,1],
        ]
    ),
}


dataset = BRASKDataSet(descriptions_dict,silver_spans )

Initiating dataset.. 
	valid
	creating clean descriptions
	we have 4 descriptions
	creating h_gs
	self embs shape: torch.Size([4, 4, 768]) should be (4, 4,hidden_size )


In [113]:

label_keys = [
    'labels_head_start', 'labels_head_end',
    'labels_tail_start','labels_tail_end'
]
pos_counts = {k: 0 for k in label_keys}
neg_counts = {k: 0 for k in label_keys}

loader = DataLoader(dataset,
                batch_size=1,
                num_workers=0,
                shuffle=False)

for batch in loader:
    for k in label_keys:
        lbl = batch[k]
        p = lbl.sum().item() 
        n = lbl.numel() - p 
        pos_counts[k] += p
        neg_counts[k] += n
{
    k: (neg_counts[k] / pos_counts[k]) if pos_counts[k] > 0 else 1.0
    for k in label_keys
    
}

{'labels_head_start': 1.6666666666666667,
 'labels_head_end': 1.6666666666666667,
 'labels_tail_start': 1.6666666666666667,
 'labels_tail_end': 1.6666666666666667}

In [8]:
import torch 


#chatgpt 4
def clean_descriptions_dict_opt(descriptions_dict, silver_spans):
    print(f"\t cleaning descriptions_dict with size {len(descriptions_dict)} ")
    all_keys = list(descriptions_dict.keys())
    silver_span_head_s = silver_spans["head_start"]
    silver_span_head_e = silver_spans["head_end"]
    silver_span_tail_s = silver_spans["tail_start"]
    silver_span_tail_e = silver_spans["tail_end"]
    print(f"\t and silver spans with shape {silver_span_tail_e.shape}")
    
    
    m1 = silver_span_head_s .any(dim=1) 
    m2 = silver_span_head_e .any(dim=1)
    m3 = silver_span_tail_s .any(dim=1) 
    m4 = silver_span_tail_e .any(dim=1)
    mask_prev = m1 & m2 & m3 & m4
    print(f"mask before: {mask_prev}")
    
    mask = (
        silver_span_head_s .any(dim=1) &
        silver_span_head_e .any(dim=1) &
        silver_span_tail_s .any(dim=1) &
        silver_span_tail_e .any(dim=1)
    )
    print(f"mask now: {mask}")
    #wow! I did not know that I can do this
    filtered_dict = {
        key: value
        for (key, value), keep in zip(descriptions_dict.items(), mask)
        if keep.item()
    }
    idx = mask.nonzero(as_tuple=True)[0] 
    return (
        filtered_dict,
        silver_span_head_s[idx],
        silver_span_head_e[idx],
        silver_span_tail_s[idx],
        silver_span_tail_e[idx],
    )


descriptions_dict = {
    "q1": "I am q1 in the field",
    "q2": "I am q2 fsd dsaonv noifd", 
    "q3": "I am q3 hello should be deleted",
    "q4": "I am q4 here but not exists",
}

silver_spans = {
    "head_start" : torch.tensor(
        [
            [1,1,1,1], 
            [0,0,0,0], 
            [1,0,0,0], 
            [0,0,0,1],
        ]
    ),
    "head_end" : torch.tensor(
        [
            [1,1,1,1], 
            [0,0,0,0], 
            [1,0,0,0], 
            [0,0,0,1],
        ]
    ),
      "tail_start" : torch.tensor(
        [
            [1,1,1,1], 
            [0,0,0,0], 
            [1,0,0,0], 
            [0,0,0,1],
        ]
    ),
    "tail_end" : torch.tensor(
        [
            [1,1,1,1], 
            [0,0,0,0], 
            [1,0,0,0], 
            [0,0,0,1],
        ]
    ),
}




clean_descriptions_dict_opt(descriptions_dict, silver_spans)

	 cleaning descriptions_dict with size 4 
	 and silver spans with shape torch.Size([4, 4])
mask before: tensor([ True, False,  True,  True])
mask now: tensor([ True, False,  True,  True])


({'q1': 'I am q1 in the field',
  'q3': 'I am q3 hello should be deleted',
  'q4': 'I am q4 here but not exists'},
 tensor([[1, 1, 1, 1],
         [1, 0, 0, 0],
         [0, 0, 0, 1]]),
 tensor([[1, 1, 1, 1],
         [1, 0, 0, 0],
         [0, 0, 0, 1]]),
 tensor([[1, 1, 1, 1],
         [1, 0, 0, 0],
         [0, 0, 0, 1]]),
 tensor([[1, 1, 1, 1],
         [1, 0, 0, 0],
         [0, 0, 0, 1]]))

In [22]:
from datasets import Dataset
from transformers import BertTokenizerFast, BertModel

tokenizer =BertTokenizerFast.from_pretrained('bert-base-cased')


dataset = Dataset.from_dict({"text": list(descriptions_dict.values())})
def tokenize_function(batch):
    return tokenizer(
        batch["text"],
        padding="max_length",
        truncation=True,
        return_tensors="pt",
        max_length=128
    )
encoded = dataset.map(
    tokenize_function,
    batched=True,
    num_proc=1
)

Map: 100%|██████████| 4/4 [00:00<00:00, 192.94 examples/s]


In [23]:
import torch 
input_ids = torch.tensor(encoded['input_ids'])
attention_mask = torch.tensor(encoded['attention_mask'])
model = BertModel.from_pretrained('bert-base-cased')

with torch.no_grad():
    bert_output = model(input_ids=input_ids, attention_mask=attention_mask)
    embeddings = bert_output.last_hidden_state # (batch_size, seq_len ,hidden_size)
embeddings.shape

torch.Size([4, 128, 768])

In [45]:
import re 


descriptions = {
    "q1": "Europe is a continent of the world",
    "q2": "Italy is a country of Europe"
}

aliases = {
    "q1": ["europe", "oropa", "eu"],
    "q2": ["italia", "italy", "it"],
}
triples = {
    "q1": [],
    "q2": [ ("q2", "r1", "q1")]
}

BATCH_SIZE = len(descriptions)
sentences_texts = list(descriptions.values())



def extract_silver_spans(descs, triples, aliases):
    CHUNK_SIZE = 1 
    L = DESCRIPTION_MAX_LENGTH = 6 
    sentences_ids = list(descs.keys())
    sentences_texts = list(descs.values())
    sentences_triples_heads_aliases = [
        [aliases[t[0]] for t in triples[s]] 
        for s in sentences_ids
    ]
    sentences_triples_tails_aliases = [
        [aliases[t[2]] for t in triples[s]] 
        for s in sentences_ids
    ]
    
    alias_pattern_map = {} 
    for lst in aliases.values():
        for alias in lst:
            escaped = re.escape(alias)
            flexible  = escaped.replace(r"\ ", r"\s*")
            pattern   = rf"\b{flexible}\b"
            alias_pattern_map[alias] = re.compile(pattern, re.IGNORECASE)


    silver_span_head_s = torch.zeros(BATCH_SIZE, L )
    silver_span_head_e = torch.zeros(BATCH_SIZE, L )
    silver_span_tail_s = torch.zeros(BATCH_SIZE, L )
    silver_span_tail_e = torch.zeros(BATCH_SIZE, L )

    all_sentences_tokens = []
    all_sentences_offsets = []

    total_batches = int(len(sentences_texts) / CHUNK_SIZE)
    for i in range(0, len(sentences_texts), CHUNK_SIZE):

        batch = sentences_texts[i : i + CHUNK_SIZE]
        enc = tokenizer(
            batch, 
            return_offsets_mapping=True,
            add_special_tokens = False,
            padding="max_length", 
            truncation=True,
            max_length=DESCRIPTION_MAX_LENGTH
            
        )
        all_sentences_offsets.extend(enc.offset_mapping)

        for sen_idx, enc_obj in enumerate(enc.encodings):
            all_sentences_tokens.append(enc_obj.tokens)

            sentence_idx_in_batch = i + sen_idx
            current_description = sentences_texts[sentence_idx_in_batch]
            sentence_heads_aliases = sentences_triples_heads_aliases[sentence_idx_in_batch]
            sentence_tails_aliases = sentences_triples_tails_aliases[sentence_idx_in_batch]
            sentence_tokens_offset = all_sentences_offsets[sentence_idx_in_batch]

            for one_als_list in sentence_heads_aliases:
                for als_str in one_als_list:
                    pattern = alias_pattern_map[als_str]
                    m = pattern.search(current_description)
                    if not m: continue 
                    start_char, end_char = m.span()
                    token_indices = [
                        i for i, (s, e) in enumerate(sentence_tokens_offset)
                        if (s < end_char) and (e > start_char)
                    ]
                    if len(token_indices) > 0:
                        head_start, head_end = token_indices[0], token_indices[-1]
                        silver_span_head_s[sentence_idx_in_batch, head_start] = 1
                        silver_span_head_e[sentence_idx_in_batch, head_end] = 1
                        break
            

            for one_als_list in sentence_tails_aliases:
                for als_str in one_als_list:
                    pattern =  alias_pattern_map[als_str]
                    m = pattern.search(current_description)
                    if not m: continue 
                    start_char, end_char = m.span()
                    token_indices = [
                        i for i, (s, e) in enumerate(sentence_tokens_offset)
                        if (s < end_char) and (e > start_char)
                    ]
                    if len(token_indices) > 0 :
                        tail_start, tail_end = token_indices[0], token_indices[-1]
                        silver_span_tail_s[sentence_idx_in_batch, tail_start] = 1
                        silver_span_tail_e[sentence_idx_in_batch, tail_end] = 1
                        break
    
    return  silver_span_head_s, silver_span_head_e,  silver_span_tail_s,silver_span_tail_e, all_sentences_tokens

silver_span_head_s,silver_span_head_e,silver_span_tail_s,silver_span_tail_e, all_sentences_tokens  = extract_silver_spans(descriptions, triples, aliases)

print(f"all_sentences_tokens: {all_sentences_tokens}")
print(f"silver_span_head_s: {silver_span_head_s}")
print(f"silver_span_head_e: {silver_span_head_e}")
print(f"silver_span_tail_s: {silver_span_tail_s}")
print(f"silver_span_tail_e: {silver_span_tail_e}")

all_sentences_tokens: [['Europe', 'is', 'a', 'continent', 'of', 'the'], ['Italy', 'is', 'a', 'country', 'of', 'Europe']]
silver_span_head_s: tensor([[0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0.]])
silver_span_head_e: tensor([[0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0.]])
silver_span_tail_s: tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1.]])
silver_span_tail_e: tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1.]])
