Things to choose:

"g_lex_utf8" or "lex" as representation of the lexeme part of the word.
g_lex_utf8 is the words with the prefixes and suffixes removed.
lex is th elexeme of the word, with disambiguation signs.
lex has the advantage of having standardized spelling and disambiguation.

With or without morpheme markers.
With morpheme markers it is possible to distinguish between J of the yiqtol and the pronominal suffix. Without markers these are seen as the same token.

Maybe:
Partly overlapping or no overlapping sequences. The problem with partly overlapping sequences is that parts of the same text occur both in raining and test set.



In [6]:
import collections, itertools, os

import numpy as np
import pandas as pd
import pickle

import torch
from torchinfo import summary
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoModel,
    BertConfig, 
    BertModel, 
    DataCollatorForLanguageModeling, 
    AutoModelForMaskedLM,
    BertForMaskedLM,
    PreTrainedTokenizerFast,
    Trainer,
    TrainingArguments
)

from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing

from tokenizers import (
    decoders,
    models,
    normalizers,
    pre_tokenizers,
    processors,
    trainers,
    Tokenizer,
)

from tf.app import use
A = use('etcbc/bhsa', hoist=globals())

A.load([
        'g_lex_utf8', 'g_prs_utf8', 'g_nme_utf8', 'g_pfm_utf8', 'g_vbs_utf8', 'g_vbe_utf8', 'g_uvf_utf8'
       ])


**Locating corpus resources ...**

Name,# of nodes,# slots / node,% coverage
book,39,10938.21,100
chapter,929,459.19,100
lex,9230,46.22,100
verse,23213,18.38,100
half_verse,45179,9.44,100
sentence,63717,6.7,100
sentence_atom,64514,6.61,100
clause,88131,4.84,100
clause_atom,90704,4.7,100
phrase,253203,1.68,100


True

In [7]:
MODEL_DIR = './models'
seq_length = 5
augment_factor = 10

# model details
num_hidden_layers = 6
num_attention_heads = 8

In [8]:
relevant_chars_utf8 = {' ',
 'א',
 'ב',
 'ג',
 'ד',
 'ה',
 'ו',
 'ז',
 'ח',
 'ט',
 'י',
 'ך',
 'כ',
 'ל',
 'ם',
 'מ',
 'ן',
 'נ',
 'ס',
 'ע',
 'ף',
 'פ',
 'ץ',
 'צ',
 'ק',
 'ר',
 'ש',
 'ת'}

In [9]:
alphabet_dict_heb = {char: char for char in relevant_chars_utf8}
double_chars = ['ןנ','ףפ', 'ץצ','ךכ','םמ']

for end_char, non_end_char in double_chars:
    alphabet_dict_heb[end_char] = non_end_char

In [10]:
# needed for conversion of lex to hebrew script, including markers =, / and [

alphabet_dict_heb_lat = {'א': '>',
                                      'ב': 'B',
                                      'ג': 'G',
                                      'ד': 'D',
                                      'ה': 'H',
                                      'ו': 'W',
                                      'ז': 'Z',
                                      'ח': 'X',
                                      'ט': 'V',
                                      'י': 'J',
                                      'כ': 'K',
                                      'ל': 'L',
                                      'מ': 'M',
                                      'נ': 'N',
                                      'ס': 'S',
                                      'ע': '<',
                                      'פ': 'P',
                                      'צ': 'Y',
                                      'ק': 'Q',
                                      'ר': 'R',
                                      'ש': 'C',
                                      'ת': 'T'}

alphabet_dict_lat_heb = {v:k for k,v in alphabet_dict_heb_lat.items()}

alphabet_dict_lat_heb['_'] = ' '
alphabet_dict_lat_heb['F'] = 'ש' + 'ׂ'
alphabet_dict_lat_heb['/'] = 'ֶ' # nouns/adjectives
alphabet_dict_lat_heb['['] = 'ַ' # verbs
alphabet_dict_lat_heb['='] = 'ֻ' # lex disambiguation marker

new_chars = ['ש', 'ׂ', 'ֶ', 'ַ', 'ֻ']

In [11]:
# Are morphemes together always identical with g_cons? 

for w in F.otype.s('word'):
    g_cons = ''.join([alphabet_dict_heb[char] for char in F.g_cons_utf8.v(w) if char in relevant_chars_utf8])
    morphs = [F.g_pfm_utf8.v(w), F.g_vbs_utf8.v(w), F.g_lex_utf8.v(w), F.g_vbe_utf8.v(w), F.g_nme_utf8.v(w), F.g_uvf_utf8.v(w), F.g_prs_utf8.v(w)]
    reconstr = ''.join([''.join([alphabet_dict_heb[char] for char in morph if char in relevant_chars_utf8]) for morph in morphs])
    if g_cons != reconstr:
        print(w, g_cons, reconstr, F.g_uvf_utf8.v(w))

In [12]:
nme_marker =  '֜'
pfm_marker =  'ְ'
vbs_marker =  'ֱ'
vbe_marker =  'ֲ'
prs_marker =  'ֳ'
uvf_marker =  'ִ'

morpheme_markers = {nme_marker, pfm_marker, vbs_marker, vbe_marker, prs_marker, uvf_marker}

# keys indicate indices of morphemes in a word
morph_marker_dict = {
    4:  '֜',
    0:  'ְ',
    1:  'ֱ',
    3:  'ֲ',
    6:  'ֳ',
    5:  'ִ'
}

In [13]:
all_chars = set()
all_chars_utf8 = set()

for w in F.otype.s('word'):
    
    morphemes_utf8 = [F.g_lex_utf8.v(w), F.g_nme_utf8.v(w), F.g_pfm_utf8.v(w), F.g_vbs_utf8.v(w), F.g_vbe_utf8.v(w), F.g_uvf_utf8.v(w), F.g_prs_utf8.v(w)]
    for morph_utf8 in morphemes_utf8:
        all_chars_utf8.update(set(morph_utf8))

    morphemes = [F.g_lex_utf8.v(w), F.g_nme_utf8.v(w), F.g_pfm_utf8.v(w), F.g_vbs_utf8.v(w), F.g_vbe_utf8.v(w), F.g_prs_utf8.v(w), F.g_uvf_utf8.v(w)]
    for morph in morphemes:
        
        all_chars.update(set(morph))

all_chars

{' ',
 '֜',
 'ְ',
 'ֱ',
 'ֲ',
 'ֳ',
 'ִ',
 'ֵ',
 'ֶ',
 'ַ',
 'ָ',
 'ֹ',
 'ֻ',
 'ּ',
 'ֿ',
 'ׁ',
 'ׂ',
 'א',
 'ב',
 'ג',
 'ד',
 'ה',
 'ו',
 'ז',
 'ח',
 'ט',
 'י',
 'ך',
 'כ',
 'ל',
 'ם',
 'מ',
 'ן',
 'נ',
 'ס',
 'ע',
 'ף',
 'פ',
 'ץ',
 'צ',
 'ק',
 'ר',
 'ש',
 'ת'}

In [14]:
def make_non_overlapping_n_grams(input_list, n):
  return [input_list[i:i+n] for i in range(0, len(input_list), n)]  #zip(*[input_list[i:] for i in range(n)])

def make_n_clause_dict(n):
    """
    Makes sequences of n clauses in the Hebrew Bible, based on a running window.
    """
    n_clause_dict = {}

    for bo in F.otype.s('book'):
        cl_n_grams = list(make_non_overlapping_n_grams(L.d(bo, 'clause'), n))
        
        for cl_n_gram in cl_n_grams:
            ch = L.u(cl_n_gram[0], 'chapter')[0]
            book, chapter_number = T.sectionFromNode(ch)
            
            words_n_clause = sorted(list(itertools.chain(*[L.d(cl, 'word') for cl in cl_n_gram])))
            n_clause_dict[(book, chapter_number, cl_n_gram, 0)] = words_n_clause

    return n_clause_dict

In [15]:
def process_one_word(w, lex_representation, relevant_chars_utf8, alphabet_dict_heb):
    if lex_representation == 'g_lex_utf8':
        morphs = [F.g_pfm_utf8.v(w), F.g_vbs_utf8.v(w), F.g_lex_utf8.v(w), F.g_vbe_utf8.v(w), F.g_nme_utf8.v(w), F.g_uvf_utf8.v(w), F.g_prs_utf8.v(w)]
    elif lex_representation == 'lex':
        lex_rep = convert_lex_to_heb_script(w)
        morphs = [F.g_pfm_utf8.v(w), F.g_vbs_utf8.v(w), lex_rep, F.g_vbe_utf8.v(w), F.g_nme_utf8.v(w), F.g_uvf_utf8.v(w), F.g_prs_utf8.v(w)]
        relevant_chars_utf8, alphabet_dict_heb = update_char_dicts(relevant_chars_utf8, alphabet_dict_heb, new_chars)

    morph_list = [''.join([alphabet_dict_heb[char] for char in morph if char in relevant_chars_utf8]) for morph in morphs]
        
    morph_list_with_markers = []
    for idx, morph in enumerate(morph_list):
        if morph:
            morph = morph + morph_marker_dict.get(idx, '')
            morph_list_with_markers.append(morph)
                
    morph_string_with_markers = ' '.join(morph_list_with_markers)
    morph_string_with_markers = ' '.join(morph_string_with_markers.split())
    
    morph_string_without_markers = ' '.join(morph_list)
    morph_string_without_markers = ' '.join(morph_string_without_markers.split())

    return morph_string_with_markers, morph_string_without_markers

In [16]:
def convert_lex_to_heb_script(tf_word_id):
    return ''.join([alphabet_dict_lat_heb[char] for char in F.lex.v(tf_word_id)])

def update_char_dicts(relevant_chars_utf8, alphabet_dict_heb, new_chars):
    for new_char in new_chars:
        relevant_chars_utf8.add(new_char)
        alphabet_dict_heb[new_char] = new_char
    return relevant_chars_utf8, alphabet_dict_heb
    

def make_morpheme_dicts(n_clause_dict, lex_representation, relevant_chars_utf8, alphabet_dict_heb):
    """
    returns:
    all_morph_strings_with_markers
    keys: (book: str, (clause ids))
    values: hebrew string with morphemes as separate words (with markers) for morpheme types
    """
    all_morph_strings_with_markers = {}
    all_morph_strings_without_markers = {}

    for key, words in n_clause_dict.items():
        morphemes_in_clause_with_markers = []
        morphemes_in_clause_without_markers = []
    
        for w in words:
            morph_string_with_markers, morph_string_without_markers = process_one_word(w, lex_representation, relevant_chars_utf8, alphabet_dict_heb)
            
            morphemes_in_clause_with_markers.append(morph_string_with_markers)
            morphemes_in_clause_without_markers.append(morph_string_without_markers)

        all_morph_strings_with_markers[key] = ' '.join(morphemes_in_clause_with_markers)
        all_morph_strings_without_markers[key] = ' '.join(morphemes_in_clause_without_markers)

    return all_morph_strings_with_markers, all_morph_strings_without_markers

In [17]:
n_clause_dict = make_n_clause_dict(seq_length)

In [18]:
len(n_clause_dict)

17640

In [19]:
person_names = {w: F.lex.v(w) for w in F.otype.s('word') if F.nametype.v(w) == 'pers'}
person_names_swapped = {v:k for k, v in person_names.items()}
person_names_swapped = {v:k for k, v in person_names.items()}
unique_person_ids = list(person_names_swapped.values())

topo_names = {w: F.lex.v(w) for w in F.otype.s('word') if F.nametype.v(w) == 'topo'}
topo_names_swapped = {v:k for k, v in topo_names.items()}
unique_topo_ids = list(topo_names_swapped.values())


In [20]:
import random

n_clause_dict_augmented = {}

for key, words in n_clause_dict.items():
    bo, ch, cl, nr = key
    n_clause_dict_augmented[key] = words
    words_copy = words.copy()
    for i in range(1, augment_factor):
        for idx, tf_id in enumerate(words):
            if  F.nametype.v(tf_id) == 'pers':
                new_name_id = random.choice(unique_person_ids)
                words_copy[idx] = new_name_id
            elif F.nametype.v(tf_id) == 'topo':
                new_topo_id = random.choice(unique_topo_ids)
                words_copy[idx] = new_topo_id
    
        key_i = (bo, ch, cl, i)
        n_clause_dict_augmented[key_i] = words_copy

len(n_clause_dict_augmented)

176400

In [21]:
# choose 'lex' or 'g_lex_utf8'
# g_lex_utf is the lexeme part of the word, but as it is found in the manuscript. A word can have various spellings
# lex is based on the lex feature. This means it has standard spelling and lexeme disambiguation.

morpheme_dataset, _ = make_morpheme_dicts(n_clause_dict_augmented, 'lex', relevant_chars_utf8, alphabet_dict_heb)

hebrew_bible_tokens = set()
for token_string in morpheme_dataset.values():
    token_set = set(token_string.split())
    hebrew_bible_tokens.update(token_set)

In [27]:
with open(f'xbib_dict_len_{seq_length}.pkl', 'rb') as f:
    xbib_morphemes_dict = pickle.load(f)

In [28]:
with open(f'syriac_dict_len_25_augm_10_s4.pkl', 'rb') as f:
    syriac_morphemes_dict = pickle.load(f)

In [29]:
morpheme_dataset = {**morpheme_dataset, **xbib_morphemes_dict}
len(morpheme_dataset)

178185

In [30]:
morpheme_dataset = {**morpheme_dataset, **syriac_morphemes_dict}
len(morpheme_dataset)

191196

In [31]:
# Remove duplicates in values (= texts)

swapped_morpheme_dataset = {v:k for k, v in morpheme_dataset.items()}
morpheme_dataset = {v:k for k, v in swapped_morpheme_dataset.items()}
len(morpheme_dataset)     

43039

In [32]:
with open('hebrew_bible_morphemes.txt', 'w', encoding='utf8') as f:
    for heb_text in morpheme_dataset.values():
        f.write(heb_text + '\n')        

In [33]:
unk_token = '[UNK]'
special_tokens = ['[UNK]', '[CLS]', '[SEP]', '[PAD]', '[MASK]']

tokenizer = Tokenizer(WordLevel(unk_token = unk_token))
trainer = WordLevelTrainer(special_tokens = special_tokens)
tokenizer.pre_tokenizer = Whitespace()

tokenizer.post_processor = TemplateProcessing(
    single = '[CLS] $A [SEP]',
    special_tokens = [('[CLS]', 1), ('[SEP]', 2)]
)

In [34]:
tokenizer.train(['hebrew_bible_morphemes.txt'], trainer)

tokenizer.save('morphemes_with_markers_tokenizer.json')
tokenizer = PreTrainedTokenizerFast(tokenizer_file = 'morphemes_with_markers_tokenizer.json')

tokenizer.add_special_tokens({'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

0

In [35]:
len(tokenizer)

12296

In [36]:
tokenizer

PreTrainedTokenizerFast(name_or_path='', vocab_size=12296, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=False, added_tokens_decoder={
	0: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	1: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	2: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	3: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	4: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}
)

In [37]:
bib_df = pd.Series(morpheme_dataset).to_frame('text')
bib_df_shuffled = bib_df.sample(frac=1, replace=False, ignore_index=True)
bib_ds = Dataset.from_pandas(bib_df_shuffled)

In [38]:
def tokenize(sentence):
  return tokenizer(sentence['text'], max_length=128, truncation=True, padding=True)

tokenized_data = bib_ds.map(tokenize, batched=True)
tokenized_data.set_format("pt", columns=["input_ids", "attention_mask"], output_all_columns=True)
tokenized_data = tokenized_data.train_test_split(test_size=0.2)


Map:   0%|          | 0/43039 [00:00<?, ? examples/s]

In [39]:
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=True, mlm_probability=.15)

In [40]:
num_hidden_layers

6

In [41]:
config = BertConfig.from_pretrained(
    'bert-base-multilingual-cased', 
    model_type='bert',
    attention_probs_dropout_prob=.5, 
    hidden_dropout_prob=.5, 
    hidden_size=256,
    intermediate_size=1024,
    max_position_embeddings=128,
    
    num_attention_heads=num_attention_heads,
    num_hidden_layers=num_hidden_layers,
    vocab_size=len(tokenizer.vocab)
    )

model = BertForMaskedLM.from_pretrained('bert-base-multilingual-cased', config=config, ignore_mismatched_sizes=True)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForMaskedLM: ['bert.encoder.layer.10.attention.output.LayerNorm.bias', 'bert.encoder.layer.10.attention.output.LayerNorm.weight', 'bert.encoder.layer.10.attention.output.dense.bias', 'bert.encoder.layer.10.attention.output.dense.weight', 'bert.encoder.layer.10.attention.self.key.bias', 'bert.encoder.layer.10.attention.self.key.weight', 'bert.encoder.layer.10.attention.self.query.bias', 'bert.encoder.layer.10.attention.self.query.weight', 'bert.encoder.layer.10.attention.self.value.bias', 'bert.encoder.layer.10.attention.self.value.weight', 'bert.encoder.layer.10.intermediate.dense.bias', 'bert.encoder.layer.10.intermediate.dense.weight', 'bert.encoder.layer.10.output.LayerNorm.bias', 'bert.encoder.layer.10.output.LayerNorm.weight', 'bert.encoder.layer.10.output.dense.bias', 'bert.encoder.layer.10.output.dense.weight', 'bert.encoder.layer.11.attention.output.LayerNorm.bias', 'bert.en

In [42]:
summary(model)

Layer (type:depth-idx)                                       Param #
BertForMaskedLM                                              --
├─BertModel: 1-1                                             --
│    └─BertEmbeddings: 2-1                                   --
│    │    └─Embedding: 3-1                                   3,147,776
│    │    └─Embedding: 3-2                                   32,768
│    │    └─Embedding: 3-3                                   512
│    │    └─LayerNorm: 3-4                                   512
│    │    └─Dropout: 3-5                                     --
│    └─BertEncoder: 2-2                                      --
│    │    └─ModuleList: 3-6                                  4,738,560
├─BertOnlyMLMHead: 1-2                                       --
│    └─BertLMPredictionHead: 2-3                             --
│    │    └─BertPredictionHeadTransform: 3-7                 66,304
│    │    └─Linear: 3-8                                      3,160,072
Tota

In [44]:
f'morphs_marks_lex_{num_hidden_layers}_layers_{num_attention_heads}_att_heads_{seq_length}_seqlen_{augment_factor}_augm_with_xbib_syr_10_augm'

'morphs_marks_lex_6_layers_8_att_heads_5_seqlen_10_augm_with_xbib_syr_10_augm'

In [None]:
def randomize_model(model):
    for module_ in model.named_modules(): 
        if isinstance(module_[1],(torch.nn.Linear, torch.nn.Embedding)):
            module_[1].weight.data.normal_(mean=0.0, std=model.config.initializer_range)
        elif isinstance(module_[1], torch.nn.LayerNorm):
            module_[1].bias.data.zero_()
            module_[1].weight.data.fill_(1.0)
        if isinstance(module_[1], torch.nn.Linear) and module_[1].bias is not None:
            module_[1].bias.data.zero_()
    return model


model = randomize_model(model)

args = TrainingArguments(output_dir=f'morphs_marks_lex_{num_hidden_layers}_layers_{num_attention_heads}_att_heads_{seq_length}_seqlen_{augment_factor}_augm_with_xbib_syr_10_augm', 
                         save_strategy='epoch',
                         learning_rate=0.0001,
                         num_train_epochs=80,
                         per_device_train_batch_size=8, 
                         per_device_eval_batch_size=8,
                         seed=42,
                        )
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_data['train'],
    eval_dataset=tokenized_data['test'],
    processing_class=tokenizer,
    data_collator=data_collator,
  )

trainer.train()
trainer.push_to_hub()


#model_name = f'bert_mt_morphemes_with_markers_based_on_lexemes_{num_hidden_layers}_layers_{num_attention_heads}_att_heads'
#trainer.save_model(os.path.join(MODEL_DIR, model_name))

Step,Training Loss
500,6.5399
1000,5.9586
1500,5.7696
2000,5.6593
2500,5.5814
3000,5.5349
3500,5.448
4000,5.4828
4500,5.4195
5000,5.4162


In [58]:
f'morphs_with_mark_based_on_lex_{num_hidden_layers}_layers_{num_attention_heads}_att_heads_{seq_length}_seqlen_{augment_factor}_augm_with_xbib_syr_no_disamb'

'morphs_with_mark_based_on_lex_4_layers_8_att_heads_5_seqlen_10_augm_with_xbib_syr_no_disamb'

In [36]:
model.eval()  

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_top_k_predictions(text, k=5):
    # Tokenize input with masking
    inputs = tokenizer(text, return_tensors="pt").to(device)
    mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]

    # Run inference
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    mask_token_logits = logits[0, mask_token_index, :]

    # Get top k tokens
    top_k_tokens = torch.topk(mask_token_logits, k, dim=1).indices[0].tolist()
    return tokenizer.convert_ids_to_tokens(top_k_tokens)

In [None]:
raw_text = 'ב ראשית ברא אלה ֜ימ את ה שמי ֜מ ו את ה ארצ ו ה ארצ הי ֲתה תהו ו בהו ו חשכ על פנ ֜י תהומ ו רוח אלה ֜ימ ְמ רחפ ֜ת על פנ ֜י ה מי ֜מ'
raw_text = 'ב ראשיתֶ בראַ אלהימֶ ֜ימ את ה שמימֶ ֜מ ו את ה ארצֶ ו ה ארצֶ היהַ ֲתה תהוֶ ו בהוֶ ו חשכֶ על פנהֶ ֜י תהומֶ ו רוחֶ אלהימֶ ֜ימ ְמ רחפַ ֶ֜ת על פנהֶ ֜י ה מימֶ ֜מ'

text_split = raw_text.split()
mask_char = '[MASK]'

for i in range(5):
    masked_text = ' '.join([text_split[i], text_split[i+1], mask_char, text_split[i+3], text_split[i+4], text_split[i+5], text_split[i+6]])
    print(masked_text)
    print('True value:', text_split[i+2])
    top_predictions = get_top_k_predictions(masked_text, k=5)
    print('Top prediction:', top_predictions[0])
    print('Predictions', top_predictions)
    print()

In [None]:
model_name = 'martijn75/bert_mt_morphemes_with_markers_based_on_lexemes_2_layers_4_att_heads'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = BertForMaskedLM.from_pretrained(model_name, 
                                        return_dict_in_generate=True, 
                                        output_hidden_states=True).to(device)

model.eval()

In [None]:
def get_hidden_states(morpheme_dataset, model, tokenizer):
    hidden_states = {}
    for key, texts_chunk in morpheme_dataset.items():
        tokenized_inputs = tokenizer(texts_chunk, max_length=128, truncation=True, padding=True, return_tensors="pt")
        tokenized_inputs = {k:v.to(device) for k,v in tokenized_inputs.items()}
        with torch.no_grad():
            outputs = model(**tokenized_inputs)
            last_hidden_states = outputs.hidden_states[-1].cpu().numpy()
            hidden_states[key] = last_hidden_states

    return hidden_states

In [None]:
def calculate_mean_hidden_states(hidden_states):
    hidden_states_mean = {}

    for key, hs in hidden_states.items():
        mean_state = np.mean(hs, 1)
        mean_state = np.squeeze(mean_state)
        hidden_states_mean[key] = mean_state
        
    return hidden_states_mean

In [None]:
hidden_states = get_hidden_states(morpheme_dataset, model, tokenizer)
mean_hidden_states = calculate_mean_hidden_states(hidden_states)

In [None]:
from scipy.spatial import distance

emb1 = mean_hidden_states.get(('Genesis', 1, (427559, 427560, 427561, 427562)))
emb2 =  mean_hidden_states.get(('Nehemiah', 10, (509499, 509500, 509501, 509502)))
distance.cosine(emb1, emb2)

In [None]:
book = 'Jeremiah'

jer_keys = [key for key in mean_hidden_states.keys() if key[0] == book]
non_jer_keys = [key for key in mean_hidden_states.keys() if key[0] != book]


In [None]:
from scipy.spatial import distance

dim = (len(jer_keys), len(non_jer_keys))

cosine_dists = np.zeros(dim)

for jer_idx, jer_key in enumerate(jer_keys):
    for non_jer_idx, non_jer_key in enumerate(non_jer_keys):
        jer_emb = mean_hidden_states.get(jer_key)
        non_jer_emb =  mean_hidden_states.get(non_jer_key)
        dist = distance.cosine(jer_emb, non_jer_emb)
        cosine_dists[jer_idx, non_jer_idx] = dist


In [None]:
min_jer_dict = collections.defaultdict(list)

min_indices = np.argmin(cosine_dists, axis=1)
len(min_indices)

for jer_idx, non_jer_idx in enumerate(min_indices):
    jer_key = jer_keys[jer_idx]
    bo, ch, _ = jer_key
    non_jer_key = non_jer_keys[non_jer_idx]
    non_jer_bo, non_jer_ch, _ = non_jer_key
    min_jer_dict[(bo, ch)].append((non_jer_bo, non_jer_ch))

    print(cosine_dists[jer_idx, non_jer_idx])
    print(non_jer_key, jer_key)
    print()


In [None]:
min_jer_dict