# Get embeddings

In [58]:
import os
from glob import glob
from transformers import BertTokenizer, BertModel
import torch

from eval import getembeddings
from eval import gpuutils

import importlib
# Re-execute eval.getembeddings so edits to it are picked up without
# restarting the notebook kernel.
importlib.reload(getembeddings)

# The handwriting data lives in a directory named 'digitalize_handwritten'
# inside the PARENT of the current working directory, so this cell only
# resolves correctly when the notebook is started from the expected folder.
ocr_data_dir = os.path.join(os.path.split(os.getcwd())[0], 'digitalize_handwritten')

# Ground-truth transcriptions sit under 'data', OCR outputs under 'OCR'.
groundtruth_dir, ocr_dir = (os.path.join(ocr_data_dir, subdir) for subdir in ('data', 'OCR'))

In [59]:
# Load the pretrained BERT tokenizer and encoder. Both come from ONE checkpoint
# name, so the tokenizer vocabulary cannot drift out of sync with the model's
# word-embedding table if the checkpoint is ever changed.
BERT_CHECKPOINT = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
model = BertModel.from_pretrained(BERT_CHECKPOINT)

# Choose the device with the project helper (presumably the GPU with the most
# free memory; confirm in eval/gpuutils) and move the model onto it. `device` is
# also handed to get_all_embeddings in the later cells.
device = gpuutils.get_gpu_most_memory()
model.to(device)

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False

## Ground truth embeddings

In [60]:
# Embed the Bentham ground-truth transcriptions (*.txt under GT/GT_Extracted),
# passing ./data/embeddings/ground_truth/bentham as the output directory.
text_glob_path = os.path.join(groundtruth_dir, 'BenthamDataset', 'GT', 'GT_Extracted', '*.txt')
save_dir = os.path.join(os.getcwd(), 'data', 'embeddings', 'ground_truth', 'bentham')

# The input location is derived from the working directory, so starting the
# notebook from the wrong folder matches no files. Fail loudly here rather than
# (presumably) letting get_all_embeddings silently embed nothing.
if not glob(text_glob_path):
    raise FileNotFoundError(f"No ground-truth text files match {text_glob_path!r}")

getembeddings.get_all_embeddings(text_glob_path, save_dir, model, tokenizer, device)


## Raw OCR embeddings

In [61]:
# Embed the raw OCR output for Bentham (*.txt under OCR/completed-OCR/Bentham),
# passing ./data/embeddings/ocr/bentham as the output directory.
text_glob_path = os.path.join(ocr_dir, 'completed-OCR', 'Bentham', '*.txt')
save_dir = os.path.join(os.getcwd(), 'data', 'embeddings', 'ocr', 'bentham')

# The input location is derived from the working directory, so starting the
# notebook from the wrong folder matches no files. Fail loudly here rather than
# (presumably) letting get_all_embeddings silently embed nothing.
if not glob(text_glob_path):
    raise FileNotFoundError(f"No OCR text files match {text_glob_path!r}")

getembeddings.get_all_embeddings(text_glob_path, save_dir, model, tokenizer, device)
