In [1]:
# in general we need to keep track of where the token of the word is.

In [2]:
import torch
import torch.nn as nn
from helper import *
from datasets import load_dataset
import spacy
from tqdm import tqdm
from torchmetrics.functional.pairwise import pairwise_cosine_similarity

[nltk_data] Downloading package punkt to
[nltk_data]     /home/enrico_benedetti/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
def get_spans(query, sentence, tokenizer, verbose=False):
    """input:
    query: the target word to find the span for. (if already a doc it should be faster)
    sentence: the sentence. (can be a doc - it will be faster)
    tokenizer: the tokenizer. assumed a direct __call__ method (e.g. from spacy tokenizers)

    output: 
    spans_ids_char, spans_ids_spacy: two lists of lists of tuples
    each list of tuples [(s, e)_i]: list of start and end idx of matches. one can take only the first one if needed.
    if no match is found, a tuple of empty lists is returned."""

    matcher = spacy.matcher.Matcher(tokenizer.vocab)
    # tokenize 
    if isinstance(sentence, spacy.tokens.doc.Doc):
        sentence_tokens = sentence
    else:
        sentence_tokens = tokenizer(sentence) # this is prob the heavy one

    if isinstance(query, spacy.tokens.doc.Doc):
        query_tokens = query
    else:
       query_tokens = tokenizer(query) # this is prob the heavy one - shorter tho
    
    pattern = [ {"LEMMA": query_token.lemma_} for query_token in query_tokens]
    matcher.add("query_match", [pattern])

    # get matches as id, start, end
    matches = matcher(sentence_tokens)
    # delete this particular one to avoid matching later
    matcher.remove("query_match")

    results_char = []
    results_spacy = []
    for match_id, start, end in matches:
        span_doc = sentence_tokens[start:end]
        #span_str = sentence_tokens.text[span_doc.start_char:span_doc.end_char]
        # if verbose:
        #     print(f'sentence: {sentence}')
        #     print(f'Span(token): {sentence_tokens[start:end]} , Span(str): {span_str}')
        #     print(f'str: {span_doc.start_char,span_doc.end_char} , indeces: {start, end}')
        results_char.append((span_doc.start_char, span_doc.end_char))
        results_spacy.append((start, end))

    return results_char, results_spacy

In [4]:
nlp = spacy.load('ja_ginza')



In [5]:
# load the mirorwic model

from transformers import AutoModel, AutoTokenizer

# Load the BERT Large model
model_name = "bennexx/mirrorwic-cl-tohoku-bert-base-japanese-v3"

model = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.cuda()
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(32768, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [6]:
# Example input text
text = "Hugging Face is creating wonderful NLP models!"

# Tokenize the input text
input_ids = tokenizer.encode(text, add_special_tokens=True, truncation=True, padding=True, return_tensors="pt")

# Make predictions using the loaded model
with torch.no_grad():
    outputs = model(input_ids.cuda())

In [16]:
dataset = load_dataset("bennexx/jp_sentences")
df = dataset['train'].to_pandas()

Downloading readme:   0%|          | 0.00/286 [00:00<?, ?B/s]

In [17]:
df

Unnamed: 0,sentence
0,藤井氏の著書の販売から、ここでしか買えない音声メールマガジンのコーナーもあります。
1,朱肉のつけすぎは、他人に写し取られて悪用されかねませんので、注意。
2,特にこれに関して習得したいと思われるテーマがあれば気軽にリクエスト下さい。
3,それに伴い、各種問い合わせを受付開始いたします。
4,英語とか韓国語、中国語での出版も希望しています。
...,...
12704955,1883年9月4日、道路、運河、港湾、鉱山に関する学校として設立された。
12704956,1901年8月17日、高等工業学校が追加された。
12704957,1975年にはムルシア大学に組み込まれてカルタヘナ工科学校となった。
12704958,1998年8月3日にムルシア大学から分離され、大学としてのカルタヘナ工科大学が開学した。


In [9]:
#sentences = ["日本語能力試験は簡単です。", "私は、日本語能力試験したことがなかった。"]
sentences = ['明日、彼は休暇を取ります。', '彼は本を取って読んでいます。', '彼は賞を取ることができました。', '彼女は試験で高得点を取りました。']
#target_word = '日本語能力試験' 
target_word = '取る'

pretoken_sentences = list(nlp.pipe(sentences))
pretoken_sentences = [[t.orth_ for t in doc ]for doc in pretoken_sentences]
print(pretoken_sentences)
#tokenized_text = tokenizer.batch_encode_plus(pretoken_sentences, is_split_into_words=True, truncation = True, padding="max_length", max_length=50)
#print(tokenized_text)

[['明日', '、', '彼', 'は', '休暇', 'を', '取り', 'ます', '。'], ['彼', 'は', '本', 'を', '取っ', 'て', '読ん', 'で', 'い', 'ます', '。'], ['彼', 'は', '賞', 'を', '取る', 'こと', 'が', 'でき', 'まし', 'た', '。'], ['彼女', 'は', '試験', 'で', '高得点', 'を', '取り', 'まし', 'た', '。']]


In [10]:
# spans_ids_char, spans_ids_spacy = get_spans(query=target_word, sentence=sentences[0], tokenizer=nlp.tokenizer)
# # span_char = spans_ids_char[0]
# # span_spacy = spans_ids_spacy[0]
# target_word_start, target_word_end = spans_ids_char[0] # take first one
# sentences[0][target_word_start:target_word_end]

In [11]:
# subtoken_start, subtoken_end = sum(subtokens_per_token[:target_word_start]), sum(subtokens_per_token[:target_word_end])
# subtoken_start, subtoken_end = get_subtoken_span(subtokens_per_token, span_spacy)
# print(subtoken_start, subtoken_end)

In [12]:
# output = model(flattened_inputs.cuda(), output_hidden_states=True)
# layer_start = 9
# layer_end = 13
# hidden_states = output.hidden_states # the first number is the layer, second is the token number, last is the vector
# average_layer_batch = sum(hidden_states[layer_start:layer_end]) / (layer_end-layer_start) # we get the token mean across the last layers
# sentence_num_in_batch = average_layer_batch.size()[0]
# for num in range(sentence_num_in_batch): # for all sentences passed
#     # here we need the start/end on subtokens
#     ###

#     print(average_layer_batch[num].size()) 
#     sentence_embeddings = average_layer_batch[num] # get a tensor num_tokens(hf) x 768 vector values
#     print(sentence_embeddings[0].shape) # sentence_embeddings[i] is the embedding of subtoken i. to get the ones for the original words we need to...
#     sentence_embeddings_no_st = sentence_embeddings[1:-1]# remove first and last (they are the cls and sep ones)
#     print(sentence_embeddings_no_st.shape) # find the embeddings of interest (the one for a certain target word) and sum? also between them / or take the first one?
    
#     target_word_embeddings = sentence_embeddings_no_st[subtoken_start:subtoken_end] # those are the corresponding embeddings
#     target_word_emb_final = target_word_embeddings.detach().cpu().numpy().mean(0) # take the mean
#     print("mean", target_word_emb_final.shape)
#     target_word_emb_final = target_word_embeddings[0].detach().cpu().numpy() # take the first token only
#     print("first", target_word_emb_final.shape)

In [13]:
# sentence_spans = []
# subtoken_spans = []
# batch_flattened_inputs = []
# for sentence, pretoken_sentence in zip(sentences, pretoken_sentences):
#     # process tokens in the batch
#     # these 2 for every sentence in the batch?
#     flattened_inputs, grouped_inputs, subtokens_per_token = tokenize_for_sense_embeddings(pretokenized_sentence=pretoken_sentence, tokenizer=tokenizer)
#     batch_flattened_inputs.append(flattened_inputs.tolist()[0]) # depack the batch

#     spans_ids_char, spans_ids_spacy = get_spans(query=target_word, sentence=sentence, tokenizer=nlp.tokenizer)
#     spans_ids_char = spans_ids_char[0] # take first occurrence
#     spans_ids_spacy = spans_ids_spacy[0] # take first occurrence
#     sentence_spans.append(spans_ids_spacy)  
#     #print(sentence_spans)
#     #print(subtokens_per_token)
#     subtoken_span = get_subtoken_span(subtokens_per_token, spans_ids_spacy)
#     subtoken_spans.append(subtoken_span)

# # need to pad and stuff...
# #print(batch_flattened_inputs)
# batch_flattened_inputs = tokenizer.batch_encode_plus(batch_flattened_inputs, is_split_into_words=True, add_special_tokens=False, padding=True, truncation=True, return_tensors='pt')
# #batch_flattened_inputs = torch.stack(batch_flattened_inputs, dim=0)

# print(subtoken_spans)
# print(batch_flattened_inputs)


# embeddings = get_embeddings(flattened_inputs=batch_flattened_inputs, subtoken_spans=subtoken_spans, model=model)

# pairwise_cosine_similarity(torch.Tensor(embeddings), zero_diagonal=True)

In [54]:
# https://www.mrklie.com/post/2020-09-26-pretokenized-bert/
def tokenize_for_sense_embeddings(pretokenized_sentence, tokenizer):
    """Processes a pretokenized sentence
    Input:
    tokenizer: the hf tokenizer"""
    grouped_inputs = [torch.LongTensor([tokenizer.cls_token_id])]
    subtokens_per_token = []

    for token in pretokenized_sentence:
        tokens = tokenizer.encode(
            token,
            return_tensors="pt",
            add_special_tokens=False,
            padding=True,
            truncation=True
        ).squeeze(axis=0)
        grouped_inputs.append(tokens)
        subtokens_per_token.append(len(tokens))

    grouped_inputs.append(torch.LongTensor([tokenizer.sep_token_id]))

    flattened_inputs = torch.cat(grouped_inputs)
    flattened_inputs = torch.unsqueeze(flattened_inputs, 0)
    return flattened_inputs, grouped_inputs, subtokens_per_token

def get_subtoken_span(subtokens_per_token, span_spacy):
    """Gets the start and end of subtokens from the spacy span."""
    target_word_start, target_word_end = span_spacy
    return sum(subtokens_per_token[:target_word_start]), sum(subtokens_per_token[:target_word_end])

def get_embeddings(flattened_inputs, subtoken_spans, model, layer_start = 9, layer_end = 13, method='mean'):
    """Gets the embeddings from the model on the average of layers. 

    Input: 
    flattened_inputs: the tokenized (hf) ids, expects a batch.
    subtoken_span: the corresponding subtoken spans to extract. also a batch.
    layer_start, layer_end: which layers to average embeddings.
    method: first, mean, sum. does not change much
    
    output:
    a batch of final embeddings for the subtoken_spans which correspond to the target word."""

    output = model(input_ids=flattened_inputs['input_ids'].cuda(),attention_mask=flattened_inputs['attention_mask'].cuda(), output_hidden_states=True)
    hidden_states = output.hidden_states # the first number is the layer, second is the token number, last is the vector
    average_layer_batch = sum(hidden_states[layer_start:layer_end]) / (layer_end-layer_start) # we get the token mean across the last layers
    sentence_num_in_batch = average_layer_batch.size()[0]
    
    out_tensors = []

    for num in range(sentence_num_in_batch): # for all sentences passed
    # here we need the start/end on subtokens
    ###
        #print(average_layer_batch[num].size()) 
        sentence_embeddings = average_layer_batch[num] # get a tensor num_tokens(hf) x 768 vector values
        #print(sentence_embeddings[0].shape) # sentence_embeddings[i] is the embedding of subtoken i. to get the ones for the original words we need to...
        sentence_embeddings_no_st = sentence_embeddings[1:-1]# remove first and last (they are the cls and sep ones) - ok theres also the pad but they are right-side
        #print(sentence_embeddings_no_st.shape) # find the embeddings of interest (the one for a certain target word) and sum? also between them / or take the first one?
        
        subtoken_start, subtoken_end =  subtoken_spans[num] # get the current spans / the one of the batch
        target_word_embeddings = sentence_embeddings_no_st[subtoken_start:subtoken_end] # those are the corresponding embeddings
        if method=='sum':
            target_word_emb_final = target_word_embeddings.sum(0)
        elif method=='mean':
            target_word_emb_final = target_word_embeddings.mean(0)#.detach().cpu().numpy().mean(0) # take the mean
        elif method=='first':
            target_word_emb_final = target_word_embeddings[0]#.detach().cpu().numpy() # take the first token only
        # checking
        #print(target_word_emb_final)
        detokenized = tokenizer.batch_decode(flattened_inputs['input_ids'][num])[1:-1] # yes
        corresponding_tokens = detokenized[subtoken_start: subtoken_end]
        print(corresponding_tokens)

        #print("mean", target_word_emb_final.shape)
        
        #print("first", target_word_emb_final.shape)
        out_tensors.append(target_word_emb_final)

    out_tensor = torch.stack(out_tensors, dim=0)
    return out_tensor.detach().cpu().numpy()

def get_sense_similarity(sentences: list, target_word, nlp, model, tokenizer, method='mean'):
    """Returns the similarity matrix and the embeddings."""
    pretoken_sentences = list(nlp.pipe(sentences))
    pretoken_sentences = [[t.orth_ for t in doc ]for doc in pretoken_sentences]
    sentence_spans = []
    subtoken_spans = []
    batch_flattened_inputs = []
    # list of indeces to zero the distances for in the matrix
    i_zero_list = []
    for i, (sentence, pretoken_sentence) in enumerate(zip(sentences, pretoken_sentences)):
        # process tokens in the batch
        # these 2 for every sentence in the batch?
        flattened_inputs, grouped_inputs, subtokens_per_token = tokenize_for_sense_embeddings(pretokenized_sentence=pretoken_sentence, tokenizer=tokenizer)
        batch_flattened_inputs.append(flattened_inputs.tolist()[0]) # depack the batch

        spans_ids_char, spans_ids_spacy = get_spans(query=target_word, sentence=sentence, tokenizer=nlp.tokenizer)
        # if empty, remember the sentence and set the distance to 0 later. give a 0,0 subtoken span
        if len(spans_ids_char) == 0:
            subtoken_span = (0,0)
            i_zero_list.append(i)
        else:
            # do it normally
            spans_ids_char = spans_ids_char[0] # take first occurrence
            spans_ids_spacy = spans_ids_spacy[0] # take first occurrence
            sentence_spans.append(spans_ids_spacy)  
            #print(sentence_spans)
            #print(subtokens_per_token)
            subtoken_span = get_subtoken_span(subtokens_per_token, spans_ids_spacy)
        subtoken_spans.append(subtoken_span)

    # need to pad and stuff...
    #print(batch_flattened_inputs)
    batch_flattened_inputs = tokenizer.batch_encode_plus(batch_flattened_inputs, is_split_into_words=True, add_special_tokens=False, padding=True, truncation=True, return_tensors='pt')
    #batch_flattened_inputs = torch.stack(batch_flattened_inputs, dim=0)

    #print(subtoken_spans)
    #print(batch_flattened_inputs)


    embeddings = get_embeddings(flattened_inputs=batch_flattened_inputs, subtoken_spans=subtoken_spans, model=model, method=method)
    similarity_matrix = pairwise_cosine_similarity(torch.Tensor(embeddings), zero_diagonal=True)

    # for the parts where there are no target words, fill with -1s
    similarity_matrix[:, i_zero_list] = -1
    similarity_matrix[i_zero_list, :] = -1
    # re_zero diagonal
    similarity_matrix.fill_diagonal_(0)
    #print(i_zero_list)

    return similarity_matrix, embeddings

In [55]:
# test on 100 sentences
# what to do if not found... probably give zero and skip - or consider that they will be there for sure
#get_sense_similarity(df['sentence'][:100], target_word='target_word', nlp=nlp, model=model, tokenizer=tokenizer)
get_sense_similarity(sentences, target_word=target_word, nlp=nlp, model=model, tokenizer=tokenizer)

['取 り']
['取 っ']
['取 る']
['取 り']


(array([[ 0.08049724,  0.00812712, -0.22313227, ..., -0.19238411,
          0.07798383, -0.4647833 ],
        [-0.3224555 , -0.45918953, -0.52518886, ...,  0.4306293 ,
          0.18063447,  0.6098743 ],
        [-0.02311909,  0.8989869 , -0.79339576, ..., -0.91200066,
          0.5188445 , -0.08833297],
        [-0.2956479 ,  0.40284643, -0.17048827, ..., -0.13746454,
          1.1316245 , -0.4609816 ]], dtype=float32),
 tensor([[0.0000, 0.4925, 0.4664, 0.5221],
         [0.4925, 0.0000, 0.4587, 0.4700],
         [0.4664, 0.4587, 0.0000, 0.7712],
         [0.5221, 0.4700, 0.7712, 0.0000]]))