In [26]:

from math import ceil
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence

from torch.utils.data import Dataset, DataLoader
from datasets import Dataset as HFDataset
from transformers import BertModel, BertTokenizer, BertTokenizerFast
import numpy as np 
from tqdm import tqdm
import pickle
import torch.multiprocessing as mp
import random
import os
from multiprocessing import Pool, cpu_count



from collections import defaultdict
import itertools
import random 
import logging

from torch.amp import autocast, GradScaler
import re 
from joblib import Parallel, delayed
from joblib import dump, load

import concurrent.futures
import multiprocessing
from functools import partial
import unicodedata

import nltk
from nltk.corpus import words

import torch 
import os 
from transformers import BertTokenizerFast


from datasets import Dataset as HFDataset 

root = "."
RAW_FOLDER = f"{root}/data/raw"
RESULTS_FOLDER = f"{root}/data/results"
HELPERS_FOLDER = f"{root}/data/helpers"
CHECKPOINT_FOLDER = f"{root}/data/checkpoints"
TEMP_FOLDER = f"{root}/data/temp"

RESULT_FILES  = {
    "descriptions_unormalized": f"{RESULTS_FOLDER}/descriptions_unormalized.pkl",
    "descriptions": f"{RESULTS_FOLDER}/descriptions.pkl",
    "aliases": f"{RESULTS_FOLDER}/aliases.pkl",
    "alias_patterns": f"{RESULTS_FOLDER}/aliases_patterns.pkl",
    "relations": f"{RESULTS_FOLDER}/relations.pkl",
    "triples": f"{RESULTS_FOLDER}/triples.pkl",
    "transE_relation_embeddings": f"{RESULTS_FOLDER}/transE_rel_embs.pkl",
    "silver_spans": {
        "head_start": f"{RESULTS_FOLDER}/ss_head_start.npz",
        "head_end": f"{RESULTS_FOLDER}/ss_head_end.npz",
        "tail_start": f"{RESULTS_FOLDER}/ss_tail_start.npz",
        "tail_end": f"{RESULTS_FOLDER}/ss_tail_end.npz",
        "sentence_tokens": f"{RESULTS_FOLDER}/ss_sentence_tokens.pkl",
        "desc_ids": f"{RESULTS_FOLDER}/desc_ids.pkl",
        
    }
}

CHECKPOINTS_FILES = {
    "transe_triples": f"{CHECKPOINT_FOLDER}/transe_triples.pkl",
    "transe_model": f"{CHECKPOINT_FOLDER}/transe_model.pth",
}

TEMP_FILES = {
    "heads_aliases": f"{TEMP_FOLDER}/heads_aliases.pkl",
    "tails_aliases": f"{TEMP_FOLDER}/tails_aliases.pkl",
    "aliases_patterns": f"{TEMP_FOLDER}/aliases_patterns.pkl",
    "sentences_tokens": f"{TEMP_FOLDER}/sentences_tokens.pkl",
    "results_spans": f"{TEMP_FOLDER}/results_spans.pkl",
}

def cache_array(ar, filename):
    with open(filename, 'wb') as f:
        pickle.dump(ar, f)
    print(f"Array chached in file {filename}")

def read_cached_array(filename):
    with open(filename, 'rb', buffering=16*1024*1024) as f:
        return pickle.load(f)

def save_tensor(tensor, path):
    os.makedirs(os.path.dirname(path) , exist_ok=True)
    np.savez_compressed(path, arr=tensor.cpu().numpy())
    print(f"Tensor chached in file {path}")



def read_tensor(path):
    print(f"reading from path {path}")
    loaded = np.load(path)
    return torch.from_numpy(loaded["arr"])


def split_list(data, num_chunks):
    chunk_size = (len(data) + num_chunks - 1) // num_chunks  # ceiling division
    return [data[i * chunk_size : (i + 1) * chunk_size] for i in range(num_chunks)]




In [6]:
def tokenize_function(batch, tokenizer, max_length):
    return tokenizer(
        batch["text"],
        padding="max_length",
        truncation=True, 
        return_tensors="pt",
        max_length = max_length
    )

In [2]:
descriptions = read_cached_array(RESULT_FILES["descriptions"])

sentences = list(descriptions.values())


In [3]:
dataset = HFDataset.from_dict({"text": sentences})


In [4]:
encoded = dataset.map(
    tokenize_function,
    batched=True,
)

Dataset({
    features: ['text'],
    num_rows: 736
})

In [25]:

attention_masks = torch.tensor([
    [1,1,0,0,0],
    [1,1,1,0,0],
]).unsqueeze(-1)
attention_masks.shape  #(2,5,1)

embs = torch.tensor([
    [ [1,2,3,1,2], [2,2,2,2,2], [3,3,3,3,3], [3,2,3,1,4], [5,4,1,5,6]  ],
    [ [1,1,1,1,1], [8,8,7,7,10], [20,40,20,60,20], [30,70,35,25,40], [50,40,10,50,60]  ],
])

# sum_embs = (embs * attention_masks).sum(dim=1)
# token_counts = attention_masks.sum(dim=1).clamp(min=1)

# print(sum_embs)
# print(token_counts)
# sum_embs / token_counts

embs.shape[1]


5

In [88]:


def extract_first_embeddings(token_embs, start_probs, end_probs, max_len, threshold=.6):
    """
        args:
            token_embs: sentence embeddings with shape (batch_size, seq_length, hidden_size)
            start_probs: probabilities that an entity is a start (start of subject for forward, start of object for backward) with shape (batch_size, seq_len)
            end_probs: probabilities that an entity is an end (end of subject for forward, end of object for backward) with shape (batch_size, seq_len)
            threshold: threshold that if the probability > threshold, then it is considered start or end of starting entity
        returns:
            padded_embs: these are padded embeddings of all subjects or objects with shape (B, max_ents, H)
            mask_embs: mask to show the padded embs (0 for padding)
            head_idxs: list (len=batch_size) where each item is a list of tuples, each tuple (start_idx, end_idx) of the entities
    """
    batch_size, seq_len, hidden_size = token_embs.shape

    start_mask = start_probs > threshold
    end_mask = end_probs > threshold

    head_idxs = []
    all_ents_embs = []
    all_masks = []

    for sent_idx in range(batch_size):
        print()
        print(f"sent idx: {sent_idx}")
        start_indices = torch.where(start_mask[sent_idx])[0]
        end_indices = torch.where(end_mask[sent_idx])[0]

        ents_embs = []
        idxs_sentence = []
        used_ends = set()

        for start_idx in start_indices.tolist():
            print(f"    start idx: {start_idx}")
            end_ptr = 0 
            while end_ptr < len(end_indices) and (  end_indices[end_ptr].item() < start_idx or end_indices[end_ptr].item() in used_ends) :
                print(f"        end ptr: {end_ptr}")
                end_ptr += 1
            if end_ptr < len(end_indices):
                end = end_indices[end_ptr].item()
                used_ends.add(end)
                idxs_sentence.append((start_idx, end))

                #compute average embedding (in paper shows ((emb of start + emb of end) / 2), I will do between them )
                sum = token_embs[sent_idx , start_idx : end + 1].sum(dim=0)
                dominator = end + 1 - start_idx
                ent_emb = sum / dominator
                ents_embs.append(ent_emb)

        head_idxs.append(idxs_sentence)

        if ents_embs:
            ent_tensor = torch.stack(ents_embs)
            all_ents_embs.append(ent_tensor)
            all_masks.append(torch.ones(ent_tensor.size(0), dtype=torch.bool, device=token_embs.device))
        else:
            all_ents_embs.append(torch.empty(0, hidden_size, device=token_embs.device))
            all_masks.append(torch.zeros(0, dtype=torch.bool, device=token_embs.device))

    # Pad sequences
    padded_embs = pad_sequence(all_ents_embs, batch_first=True, padding_value=0.0)
    mask_embs = pad_sequence(all_masks, batch_first=True, padding_value=False)

    if max_len is not None:
        # Pad or truncate embeddings
        curr_len = padded_embs.size(1)
        if curr_len < max_len:
            pad_size = max_len - curr_len
            padding = torch.zeros(padded_embs.size(0), pad_size, padded_embs.size(2), device=padded_embs.device)
            padded_embs = torch.cat([padded_embs, padding], dim=1)

            mask_padding = torch.zeros(mask_embs.size(0), pad_size, dtype=torch.bool, device=mask_embs.device)
            mask_embs = torch.cat([mask_embs, mask_padding], dim=1)

        elif curr_len > max_len:
            padded_embs = padded_embs[:, :max_len, :]
            mask_embs = mask_embs[:, :max_len]

    return padded_embs, mask_embs, head_idxs





In [96]:
start_probs = torch.tensor([
    [.6,.6,.3,.3],
    [.1,.6,.3,.3],
    [.3,.6,.3,.3]
])
end_probs = torch.tensor([
    [.6,.3,.3,.6],
    [.3,.3,.3,.6],
    [.6,.3,.3,.3]
])


token_embs = torch.tensor([
    [[4,2,1,6, 5,9], [1,7,9,3,5,6], [4,7,1,7,3,2], [3,1,9,3,1,4]],
    [[4,2,1,6, 5, 9], [1,7,9,3,5,6], [4,7,1,7,3,2], [3,1,9,3,1,4]],
    [[4,2,1,6, 5, 9], [1,7,9,3,5,6], [4,7,1,7,3, 2], [3,1,9,3,1,4]],
])

In [97]:
padded_embs, mask_embs, head_idxs = extract_first_embeddings(token_embs, start_probs, end_probs, 4, threshold=.5)

padded_embs


sent idx: 0
    start idx: 0
    start idx: 1
        end ptr: 0

sent idx: 1
    start idx: 1

sent idx: 2
    start idx: 1
        end ptr: 0


tensor([[[4.0000, 2.0000, 1.0000, 6.0000, 5.0000, 9.0000],
         [2.6667, 5.0000, 6.3333, 4.3333, 3.0000, 4.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[2.6667, 5.0000, 6.3333, 4.3333, 3.0000, 4.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]]])