In [5]:
import torch
import pandas as pd
from transformers import BertModel, BertTokenizer

In [6]:
# Read the press-release corpus; lines=True parses the file as one JSON
# record per line.
data = pd.read_json(
    'department_justice_press_releases.json',
    lines=True,
)
# Only the final expression of a cell is rendered, so the shape below is what
# shows up as this cell's output.
data.head()
data.shape

(13087, 6)

## Load pretrained BERT model and tokenizer

In [7]:
# One checkpoint identifier feeds both loaders so the encoder and its
# tokenizer always come from the same pretrained checkpoint.
model_name = 'bert-base-uncased'
model = BertModel.from_pretrained(model_name)
tokenizer = BertTokenizer.from_pretrained(model_name)

## Preprocessing

In [8]:
def get_document_embedding(document):
    """Embed a single document as the mean of its BERT token vectors.

    The text is tokenized with the module-level ``tokenizer`` (truncated to
    at most 512 tokens), passed through the module-level ``model`` with
    gradient tracking disabled, and the final-layer hidden states are
    averaged over the sequence dimension (dim=1).

    Args:
        document: Raw text of one document.

    Returns:
        torch.Tensor: the pooled embedding, with the sequence dimension
        reduced away (the leading batch dimension of size one produced by
        ``return_tensors='pt'`` is kept, as before).
    """
    # A lone sequence needs no padding. Padding to 512 made every forward
    # pass full-length and, combined with a plain mean, mixed [PAD] vectors
    # into the embedding of every document shorter than 512 tokens.
    encoded = tokenizer(
        document,
        max_length=512,  # Max length for BERT input
        truncation=True,
        return_tensors='pt'
    )

    with torch.no_grad():
        outputs = model(**encoded)
        embeddings = outputs.last_hidden_state

    # Attention-mask-weighted mean: only real tokens contribute, so the
    # result stays correct even if padding is reintroduced (e.g. batching).
    mask = encoded['attention_mask'].unsqueeze(-1).to(embeddings.dtype)
    token_count = mask.sum(dim=1).clamp(min=1)  # guard against divide-by-zero
    document_embedding = (embeddings * mask).sum(dim=1) / token_count

    return document_embedding

# Embed every document. Only the 'contents' column is read, so apply over that
# column directly instead of building a full row object per record with
# axis=1. This still runs one model forward pass per document, so expect it
# to be slow on the full frame.
data['document_embedding'] = data['contents'].apply(get_document_embedding)

KeyboardInterrupt: 

In [None]:
# Preview the first five rows of the frame.
data.head(n=5)

{'input_ids': [101, 6734, 1010, 5392, 1012, 1516, 14467, 28609, 9587, 3511, 6784, 1010, 2603, 1010, 2040, 2001, 7979, 1999, 2286, 1997, 7161, 2000, 2224, 1037, 5195, 1997, 3742, 6215, 1006, 14792, 1007, 1999, 4434, 2007, 1037, 5436, 2000, 20010, 21149, 1037, 4316, 5968, 2012, 2019, 3296, 4234, 3392, 7497, 5103, 1999, 6734, 1010, 2001, 7331, 2651, 2000, 3710, 2382, 2086, 1999, 3827, 1010, 2628, 2011, 1037, 6480, 2744, 1997, 13588, 2713, 1012, 9587, 3511, 6784, 1010, 1037, 27558, 1057, 1012, 1055, 1012, 6926, 2013, 14717, 1998, 2280, 6319, 1997, 2522, 26585, 6856, 1010, 5392, 1010, 2001, 4727, 2006, 13292, 1012, 2656, 1010, 2230, 1010, 2044, 2002, 4692, 2000, 20010, 21149, 2054, 2002, 3373, 2000, 2022, 2019, 14792, 1011, 14887, 3158, 2008, 2001, 9083, 2379, 1996, 3392, 7497, 5103, 1999, 6734, 1012, 1996, 6545, 2001, 1996, 12731, 13728, 12758, 1997, 1037, 2146, 1011, 2744, 16382, 3169, 1010, 2076, 2029, 9587, 3511, 6784, 2001, 17785, 4876, 2005, 2706, 2004, 2010, 5968, 5436, 2764, 1012, 1