In [1]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
import os

def generate_embeddings_batch(texts, tokenizer, model, device):
    """
    Generate CLS token embeddings for a batch of texts using BioClinicalBERT.
    """
    inputs = tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(device)
    
    with torch.no_grad():
        outputs = model(**inputs)
        cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().tolist()
    return cls_embeddings

if __name__ == "__main__":
    # Load BioClinicalBERT model
    print("Loading BioClinicalBERT model...")
    tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    print("Model loaded.")
        
    # ===============================================================================  
    # File paths
    thePath = os.getenv('thePath', '../../../Data/unstructured')

    # Input and output paths
    input_path = f'{thePath}/summarized/merged_summaries.csv'
    output_path = f'{thePath}/emb/merged_embeddings.csv'

    # ===============================================================================  
    # Load input data
    print("Loading input data...")
    data = pd.read_csv(input_path, dtype={"HADM_ID": str, "SUBJECT_ID": str}).head(10)
    print(f"Dataset loaded. Total rows: {len(data)}")

    # ===============================================================================  
    # Ensure text fields are non-null and converted to strings
    print("Preprocessing text columns...")
    text_columns = ["TEXT", "1_t5_small2_SUMMARY", "3_bart_large_cnn_SUMMARY", "4_medical_summarization_SUMMARY"]
    for col in text_columns:
        data[col] = data[col].fillna("").astype(str)

    # ===============================================================================  
    # Initialize storage for embeddings
    embeddings = {"HADM_ID": data["HADM_ID"], "SUBJECT_ID": data["SUBJECT_ID"]}
    batch_size = 32  # Adjust based on GPU memory

    # ===============================================================================  
    # Generate embeddings for each text column
    for col in text_columns:
        print(f"Generating embeddings for '{col}' column...")
        emb_list = []
        for i in tqdm(range(0, len(data), batch_size), desc=f"Processing '{col}'"):
            batch_texts = data[col][i:i+batch_size].tolist()
            batch_embeddings = generate_embeddings_batch(batch_texts, tokenizer, model, device)
            emb_list.extend(batch_embeddings)
        
        # Store embeddings as a new column
        embeddings[f"EMB_{col}"] = emb_list

    # ===============================================================================  
    # Create new DataFrame with embeddings
    print("Creating final dataframe with embeddings...")
    result_df = pd.DataFrame(embeddings)

    # ===============================================================================  
    # Save to CSV
    print(f"Saving embeddings to {output_path}...")
    result_df.to_csv(output_path, index=False)
    print(f"Embeddings successfully saved to {output_path}.")


  from .autonotebook import tqdm as notebook_tqdm


Loading BioClinicalBERT model...
Model loaded.
Loading input data...
Dataset loaded. Total rows: 10
Preprocessing text columns...
Generating embeddings for 'TEXT' column...


Processing 'TEXT': 100%|██████████| 1/1 [00:08<00:00,  8.40s/it]


Generating embeddings for '1_t5_small2_SUMMARY' column...


Processing '1_t5_small2_SUMMARY': 100%|██████████| 1/1 [00:00<00:00,  1.04it/s]


Generating embeddings for '3_bart_large_cnn_SUMMARY' column...


Processing '3_bart_large_cnn_SUMMARY': 100%|██████████| 1/1 [00:00<00:00,  8.30it/s]


Generating embeddings for '4_medical_summarization_SUMMARY' column...


Processing '4_medical_summarization_SUMMARY': 100%|██████████| 1/1 [00:00<00:00,  3.41it/s]


Creating final dataframe with embeddings...
Saving embeddings to ../../../Data/unstructured/emb/merged_embeddings.csv...
Embeddings successfully saved to ../../../Data/unstructured/emb/merged_embeddings.csv.


In [2]:
result_df

Unnamed: 0,HADM_ID,SUBJECT_ID,EMB_TEXT,EMB_1_t5_small2_SUMMARY,EMB_3_bart_large_cnn_SUMMARY,EMB_4_medical_summarization_SUMMARY
0,100001.0,58526.0,"[-0.06878875941038132, -0.0074307480826973915,...","[0.19394980370998383, 0.28114089369773865, 0.0...","[-0.5926619172096252, 0.5007714629173279, -0.5...","[0.09199032932519913, 0.008501296862959862, 0...."
1,100003.0,54610.0,"[0.11134085804224014, 0.23517034947872162, -0....","[0.22276471555233002, 0.11342918127775192, 0.0...","[0.06475773453712463, 0.11684277653694153, -0....","[0.06370345503091812, -0.03363834321498871, -0..."
2,100006.0,9895.0,"[-0.18066149950027466, 0.18932631611824036, -0...","[0.04641161486506462, -0.1912076324224472, -0....","[-0.14331841468811035, 0.11323652416467667, -0...","[-0.11868789792060852, -0.09611806273460388, -..."
3,100007.0,,"[-0.1941411793231964, 0.12399176508188248, -0....","[0.06885648518800735, 0.3670652210712433, -0.5...","[0.14577215909957886, 0.1494465172290802, -0.4...","[0.2473163604736328, -0.0019324790919199586, -..."
4,100009.0,533.0,"[0.10732075572013855, -0.08329775929450989, -0...","[0.06423552334308624, -0.018912216648459435, -...","[-0.15307846665382385, 0.1643681824207306, -0....","[0.15557502210140228, 0.06813570111989975, -0...."
5,100010.0,55853.0,"[-0.17905670404434204, -0.04523336887359619, -...","[-0.008342467248439789, 0.19690968096256256, -...","[0.16566704213619232, -0.043629612773656845, -...","[-0.1264999806880951, 0.21646368503570557, -0...."
6,100011.0,87977.0,"[-0.078480064868927, -0.13545627892017365, -0....","[0.13416405022144318, 0.2285039722919464, 0.26...","[0.293035626411438, 0.23245438933372498, -0.40...","[-0.023307284340262413, -0.24253444373607635, ..."
7,100012.0,60039.0,"[0.019372183829545975, -0.13762210309505463, 0...","[0.2605943977832794, 0.0402924083173275, 0.428...","[0.11850736290216446, -0.15921162068843842, 0....","[0.10684802383184433, 0.042076583951711655, 0...."
8,100016.0,68591.0,"[0.21100114285945892, -0.03668534383177757, 0....","[0.5256306529045105, -0.1084759533405304, -0.0...","[0.38914477825164795, 0.04838335141539574, -0....","[0.09574245661497116, -0.08483774214982986, 0...."
9,100018.0,58128.0,"[0.178505077958107, 0.14203724265098572, -0.29...","[0.28787025809288025, -0.07443398982286453, -0...","[0.21366991102695465, 0.08555866777896881, -0....","[0.30290570855140686, -0.11699365824460983, -0..."
