In [34]:
from generationary_paired_data_creator import get_closest_words_with_sentences
from data.db import DbConnection
from typing import Tuple, Optional
import numpy as np
import spacy
from nlp.embedding import EmbeddingExtractor

# spaCy pipeline; embed_word reads each token's text, lemma_ and pos from it.
nlp = spacy.load('en_core_web_md')
# Contextual embedding extractor. The captured output of this cell mentions
# bert-base-uncased, so this presumably loads a BERT model — TODO confirm.
extractor = EmbeddingExtractor()

# Word-occurrence DB (queried in get_closest_words) and sentence DB (passed
# through to get_closest_words_with_sentences). Opened in this order.
word_db, sentence_db = (
    DbConnection('/home/josh/scrapbox/toudai/cuwsdm_words'),
    DbConnection('/home/josh/scrapbox/toudai/cuwsdm_sentences'),
)


def embed_word(sentence: str, word: str) -> Tuple[str, str, int, np.ndarray]:
    """Return the contextual embedding of the first occurrence of ``word``.

    The sentence is parsed with the module-level spaCy pipeline ``nlp``. The
    resulting doc is passed to ``extractor.get_word_embeddings``, which yields
    ``(token, embedding)`` pairs. The first token whose text equals ``word``
    exactly (case-sensitive) is returned.

    Args:
        sentence: Raw sentence text.
        word: Surface form of the target token, compared against ``token.text``.

    Returns:
        ``(text, lemma, pos, embedding)`` for the matching token, where ``pos``
        is spaCy's integer part-of-speech id (``token.pos``).

    Raises:
        LookupError: If no token's text equals ``word``.
    """
    doc = nlp(sentence)

    for token, embedding in extractor.get_word_embeddings(doc):
        if token.text == word:
            return token.text, token.lemma_, token.pos, embedding

    # LookupError rather than StopIteration has two benefits:
    # - Callers driven by next() (e.g. map()) read a StopIteration escaping
    #   this function as "iterator exhausted" and silently truncate their
    #   results; a LookupError cannot be mistaken for that.
    # - Inside a generator, a StopIteration is turned into RuntimeError
    #   (PEP 479); a LookupError surfaces as itself.
    raise LookupError(f"Couldn't find target word {word}")



def _build_where_clause(target_lemma, target_pos: Optional[int] = None) -> str:
    """Build the WHERE clause matching ``target_lemma`` as a form or a lemma."""
    # read_words is given a raw clause string here, so escape the literal.
    # Standard SQL doubles embedded single quotes. This keeps lemmas such as
    # "o'clock" from breaking the statement or injecting SQL.
    quoted = "'" + str(target_lemma).replace("'", "''") + "'"
    match = f'form={quoted} or lemma={quoted}'
    # `is not None` rather than truthiness, so a POS id of 0 still filters.
    if target_pos is None:
        return f'where {match}'
    # int() guarantees that only a number is interpolated for pos.
    return f'where pos={int(target_pos)} and ({match})'


def get_closest_words(target_lemma, target_embedding, target_pos: Optional[int] = None):
    """Look up DB occurrences of a lemma and rank them against an embedding.

    Prints the generated WHERE clause as a side effect.

    Args:
        target_lemma: Matched against both the ``form`` and ``lemma`` columns
            of the word DB.
        target_embedding: Passed through to
            ``get_closest_words_with_sentences``.
        target_pos: Optional integer POS id. When given (including 0), rows
            are additionally filtered on ``pos``.

    Returns:
        The result of ``get_closest_words_with_sentences`` for the matching
        words. These are presumably ranked by closeness — TODO confirm.

    Raises:
        RuntimeError: If no row matches the clause.
    """
    where_clause = _build_where_clause(target_lemma, target_pos)
    print(where_clause)

    words = list(word_db.read_words(use_tqdm=False, where_clause=where_clause))
    if not words:
        raise RuntimeError(f'{target_lemma} not found in db')

    return get_closest_words_with_sentences(target_embedding, words, sentence_db)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [35]:
# Embed "hate" in context; keep its surface form, lemma, POS id and vector.
form, lemma, pos, embedding = embed_word(sentence='I hate lizard people', word='hate')

In [36]:
# NOTE(review): `pos` from the previous cell is not passed, so the lookup is
# not POS-filtered (see the printed clause) — confirm this is intended.
closest_words_with_sentences = get_closest_words(target_lemma=lemma, target_embedding=embedding)


where form='hate' or lemma='hate'


[Word(id=460678, form='hates', lemma='hate', pos=100, sentence='He also hates pumpernickel bread.', embedding=array([ 5.3497e-02,  1.8237e-01, -2.7771e-02,  1.8774e-01, -6.4160e-01,
         1.4172e-01,  2.1704e-01,  3.8037e-01, -5.5161e-03, -1.2756e-01,
         3.2080e-01, -1.5991e-01,  2.2095e-02,  3.0298e-01, -2.0844e-02,
         8.0566e-02,  4.0558e-02, -3.4790e-02, -9.7168e-02,  2.3840e-01,
         1.1249e-01,  1.4026e-01,  2.6221e-01,  2.2974e-01,  2.1863e-01,
        -1.5747e-02,  1.8628e-01, -3.3765e-01, -3.4888e-01, -1.6296e-01,
        -7.1228e-02,  3.0127e-01,  4.7461e-01, -3.0615e-01, -1.7578e-01,
        -3.1714e-01, -2.6953e-01,  3.3105e-01,  6.3232e-02,  2.9272e-01,
         1.3208e-01,  3.2471e-02,  2.5366e-01,  2.1289e-01, -7.8369e-02,
         9.6313e-02,  5.6250e-01,  2.3743e-01, -9.3323e-02, -3.1689e-01,
         4.8535e-01,  1.0736e-01, -1.5869e-02, -6.5674e-02, -1.3843e-01,
        -1.3660e-01, -3.2104e-01, -1.9119e-02,  2.3608e-01, -3.1958e-01,
        -2.2778