In [1]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

def generate_embedding(text, tokenizer=None, model=None, device=None):
    """
    Generate an embedding for a single text using BioClinicalBERT.

    The embedding is the last-layer hidden state of the first token
    (the [CLS] position) of the tokenized text.

    Parameters
    ----------
    text : object
        Text to embed. Anything that is not a ``str`` (NaN, None, pd.NA,
        numbers, lists, ...) is treated as missing and yields ``None``.
    tokenizer, model, device : optional
        Hugging Face tokenizer, model and torch device to use. When omitted,
        the notebook-level ``tokenizer``, ``model`` and ``device`` defined in
        the model-loading cell are used, so existing one-argument calls keep
        working.

    Returns
    -------
    list[float] or None
        The first-token hidden state as a plain Python list, or ``None`` for
        missing / non-string input.

    Raises
    ------
    NameError
        If a component is neither passed in nor defined at notebook level.
    """
    # Check the type first: pd.isna() on a list/array returns an array whose
    # truth value is ambiguous (ValueError). Every missing marker (NaN, None,
    # pd.NA) is also a non-str, so this single check covers both cases.
    if not isinstance(text, str):
        return None

    def _resolve(value, name):
        # Prefer an explicitly passed component; otherwise fall back to the
        # object created in the notebook's model-loading cell.
        if value is not None:
            return value
        try:
            return globals()[name]
        except KeyError:
            raise NameError(
                f"generate_embedding: `{name}` is not defined; run the "
                f"model-loading cell first or pass {name}=... explicitly"
            ) from None

    tokenizer = _resolve(tokenizer, "tokenizer")
    model = _resolve(model, "model")
    device = _resolve(device, "device")

    # Texts longer than 512 tokens are truncated rather than rejected.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    # Batch of one: take sample 0, token position 0. Indexing explicitly
    # (instead of .squeeze()) always yields a 1-D vector -> list.
    return outputs.last_hidden_state[0, 0, :].cpu().tolist()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
if __name__ == "__main__":
    # File paths
    # Earlier configuration (social-history extraction), kept for reference:
    # input_file = 'data/Social_History_extraction_full_v1.csv'
    # output_file = 'data/social_history_embeddings_output.csv'
    # text_column = 'SOCIAL_HISTORY'

    input_file  = '../../data/text/summary_NOTEEVENTS_60_150_2.csv'
    output_file = '../../data/emb/summary_NOTEEVENTS_60_150_2.csv'  # only used by the (disabled) save step below
    text_column = 'SUMMARY'

    # Load BioClinicalBERT model. tokenizer / model / device are notebook-level
    # globals that generate_embedding reads.
    print("Loading BioClinicalBERT...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    model.to(device)
    # Inference only: make sure dropout is off so embeddings are deterministic.
    model.eval()
    print("Model loaded.")

    # Load input data
    print("Loading input data...")
    data = pd.read_csv(input_file)
    print(f"Input data loaded. Total rows: {len(data)}")

    # Fail fast on a mis-configured column name (e.g. switching between the
    # configurations above) before any expensive embedding work.
    if text_column not in data.columns:
        raise KeyError(
            f"Text column {text_column!r} not found in {input_file}; "
            f"available columns: {list(data.columns)}"
        )
    # Rows whose text is missing get a None embedding from generate_embedding.
    print(f"Rows with missing {text_column}: {data[text_column].isna().sum()}")

    # NOTE(review): this preview writes raw note text and identifier columns
    # (SUBJECT_ID, HADM_ID) into the saved notebook output; clear outputs
    # before sharing if the data is under a restricted-use agreement.
    print(data.head(5))

    # Embedding generation is currently disabled; uncomment to run it.
    # NOTE(review): the fillna("") step turns missing text into "", which
    # generate_embedding will embed instead of returning None. Drop that
    # line if missing notes should keep a None embedding.
    # print("Ensuring consistent data types for the text column...")
    # data[text_column] = data[text_column].fillna("").astype(str)

    # # Generate embeddings
    # print("Generating embeddings...")
    # embeddings = []
    # for index, row in tqdm(data.iterrows(), total=len(data), desc="Generating Embeddings"):
    #     embeddings.append(generate_embedding(row[text_column]))

    # # Save embeddings
    # print("Saving embeddings to output file...")
    # data['EMBEDDINGS'] = embeddings
    # data.to_csv(output_file, index=False)
    # print(f"Embeddings saved to {output_file}")


Loading BioClinicalBERT...
Model loaded.
Loading input data...
Input data loaded. Total rows: 59652
   ROW_ID  SUBJECT_ID   HADM_ID  \
0     174       22532  167853.0   
1     175       13702  107527.0   
2     176       13702  167118.0   
3     177       13702  196489.0   
4     178       26880  135453.0   

                                             SUMMARY  
0  radiologic studies also included a chest CT, w...  
1  this is an 81-year-old female with a history o...  
2  this 81 year old woman has a history of COPD ....  
3  EMS found patient tachypnic at saturating 90% ...  
4  Mr. [**Known lastname 1829**] was seen at [**H...  
