# Get embeddings

In [31]:
import os
from glob import glob
from transformers import BertTokenizer, BertModel
import torch

from eval import getembeddings
from eval import gpuutils

import importlib
importlib.reload(getembeddings)



# The OCR project is expected to be a sibling of the current working directory:
#   <parent of cwd>/digitalize_handwritten/{data, OCR}
ocr_data_dir = os.path.join(os.path.dirname(os.getcwd()), 'digitalize_handwritten')
groundtruth_dir, ocr_dir = (
    os.path.join(ocr_data_dir, subdir) for subdir in ('data', 'OCR')
)

In [33]:
# Load the pretrained "bert-base-uncased" tokenizer and encoder (tokenizer first).
tokenizer, model = (
    loader.from_pretrained("bert-base-uncased")
    for loader in (BertTokenizer, BertModel)
)

# Move the encoder to the device chosen by the project's gpuutils helper
# (presumably the GPU with the most free memory, judging by its name — not
# visible here). Left as the cell's last expression so the notebook displays
# the module, as before.
device = gpuutils.get_gpu_most_memory()
model.to(device)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

## Ground truth embeddings

In [None]:
# Compute a BERT embedding for every Bentham ground-truth transcription and
# cache it as <cwd>/data/groundtruth_dir/bentham/<name>.pt, so re-running this
# cell only processes files that have no cached embedding yet.
# Set `max_new_embeddings` to a small int (e.g. 1) to smoke-test the pipeline
# on a few files; None processes everything.
max_new_embeddings = None

# Sorted so processing order (and which files a limited run touches) is stable.
groundtruth_fns = sorted(glob(os.path.join(groundtruth_dir, 'BenthamDataset', 'GT', 'GT_Extracted', '*.txt')))

save_dir = os.path.join(os.getcwd(), 'data', 'groundtruth_dir', 'bentham')
os.makedirs(save_dir, exist_ok=True)

n_computed = 0
for gt_fn in groundtruth_fns:
    # splitext only swaps the final extension (str.replace would rewrite every '.txt').
    stem = os.path.splitext(os.path.basename(gt_fn))[0]
    save_path = os.path.join(save_dir, stem + '.pt')

    # Skip files already embedded by a previous run; check before reading the file.
    if os.path.exists(save_path):
        continue
    if max_new_embeddings is not None and n_computed >= max_new_embeddings:
        break

    # Explicit encoding so results don't depend on the platform default
    # (assumes the GT transcriptions are UTF-8 — confirm for this dataset).
    with open(gt_fn, encoding='utf-8') as f:
        gt_text = f.read()

    embedding = getembeddings.get_embedding(gt_text, model=model, tokenizer=tokenizer, device=device, max_len=512)
    # Persist the result; previously it was computed and discarded, so the cache
    # check above could never hit. Assumes get_embedding returns an object
    # torch.save can serialize (e.g. a tensor) — verify against getembeddings.
    torch.save(embedding, save_path)
    n_computed += 1
    


    

{'input_ids': tensor([[  101,  4029,  2575,  1012,  1996,  3350,  1997,  1996,  8147,  1010,
          9530,  5332, 19225,  2000,  1037, 12109,  6602,  1010,  2612,  1997,
          1037,  4964,  2338,  1012,  1035,  2579,  2013,  4654,  5403,  1011,
         10861,  2099,  8236,  1010,  1035, 12980,  2013,  4518,  5754, 14663,
          3111,  1035, 10217,  2007,  3493,  2139, 10609, 22662,  1010,  1998,
          1996,  2085, 26958,  3212, 10967,  8525,  1011,  2035,  2075,  3665,
          1998, 14445,  8236,  2030,  2139, 10609, 22662,  1024,  1035,  2036,
          2007,  2634,  9547,  1010,  2924,  3964,  1010, 13448,  1005,  1055,
         20877, 14643, 10253,  3964,  1025,  1998,  2797, 20877, 14643, 10253,
          3964,  1998,  8236,  1997,  3863,  1012,  1021,  1012,  1996,  3259,
          1010,  2011,  2049,  2946,  1010,  4338,  1010, 14902,  1010,  2024,
          4857,  2791,  1010, 11968,  1011,  1011, 14841, 15431,  2135,  7130,
          2005,  9141,  1035,  2579,  