# ClinicalBERT Embeddings for TCGA Pathology Reports

This project generates **case-level text embeddings** from TCGA pathology reports using **Bio_ClinicalBERT**.

## Overview

- Model: Bio_ClinicalBERT
- Input: OCR-parsed TCGA pathology reports (`.p` files)
- Output: 768-dimensional embedding per TCGA case
- Pooling strategy:
  - Mean pooling over tokens
  - Token-count–weighted averaging across report pages

## Method

1. Extract LINE-level text from each pathology report page
2. Encode each page's text using Bio_ClinicalBERT (input is truncated to 512 tokens per page)
3. Compute page embeddings via mean pooling
4. Aggregate page embeddings into a case-level embedding weighted by token count

In [None]:
import os
import pickle
import torch
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModel

# ======================
# Configuration
# ======================
# Filesystem locations.
DATA_DIR: str = "/content/aws_response/"  # directory scanned for ".p" report files
OUTPUT_CSV: str = "text_embeddings.csv"  # CSV file the embedding table is written to

# Encoder settings.
MODEL_NAME: str = "emilyalsentzer/Bio_ClinicalBERT"  # checkpoint given to from_pretrained
MAX_LENGTH: int = 512  # token cap passed to the tokenizer together with truncation=True

# ======================
# Load model
# ======================
# Run on the GPU when torch reports one as available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer and encoder are both loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

model = AutoModel.from_pretrained(MODEL_NAME)
model.to(device)  # moves the module's parameters in place
model.eval()  # switch to evaluation mode before running inference

# ======================
# Load files
# ======================
file_list = [f for f in os.listdir(DATA_DIR) if f.endswith(".p")]

# Group the page files by case ID (the first 12 characters of the file name)
# in a single pass, so the full file list is not rescanned for every case.
files_by_case = {}
for fname in file_list:
    files_by_case.setdefault(fname[:12], []).append(fname)

case_ids = sorted(files_by_case)

# One row per case: [case_id, emb_0, emb_1, ...]
results = []

# ======================
# Process cases
# ======================
for case_id in tqdm(case_ids, desc="Processing cases"):
    case_files = sorted(files_by_case[case_id])

    # Running token-count-weighted sum of page embeddings for this case.
    weighted_sum = None
    weight_total = 0.0

    for fname in case_files:
        # NOTE(security): pickle.load can execute arbitrary code while
        # deserializing. Only run this on report files from a trusted source.
        with open(os.path.join(DATA_DIR, fname), "rb") as f:
            parsed_data = pickle.load(f)

        # Keep only LINE blocks that actually carry a text string; a LINE
        # block without "Text" is skipped instead of raising KeyError.
        blocks = parsed_data.get("Blocks", [])
        page_lines = [
            b["Text"]
            for b in blocks
            if b.get("BlockType") == "LINE" and isinstance(b.get("Text"), str)
        ]
        page_text = "\n".join(page_lines).strip()

        # Pages with no extracted text contribute nothing to the case.
        if not page_text:
            continue

        # NOTE: truncation=True with max_length=MAX_LENGTH means any text
        # beyond MAX_LENGTH tokens on a page is dropped and not embedded.
        inputs = tokenizer(
            page_text,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=MAX_LENGTH,
        )

        # Page weight = number of attended tokens in the encoded page.
        page_weight = float(inputs["attention_mask"].sum().item())
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = model(**inputs)
            # Mean over the token dimension, then drop the batch dimension.
            # A single sequence is encoded per call, so no pad tokens are
            # expected in this mean (see the tokenizer padding docs).
            page_emb = outputs.last_hidden_state.mean(dim=1).squeeze(0)

        contribution = page_emb * page_weight
        if weighted_sum is None:
            weighted_sum = contribution
        else:
            weighted_sum += contribution

        weight_total += page_weight

    # Skip cases in which no page produced any text.
    if weighted_sum is None or weight_total == 0:
        continue

    case_embedding = weighted_sum / weight_total
    results.append([case_id] + case_embedding.cpu().tolist())

# ======================
# Save results
# ======================
# Take the embedding width from the loaded model's config instead of
# hard-coding it, so the column count keeps matching the rows in `results`
# if MODEL_NAME is changed to a model with a different hidden size.
embedding_dim = model.config.hidden_size
columns = ["case_id"] + [f"DL_{i}" for i in range(embedding_dim)]

# One row per case: the case ID followed by its embedding values.
df = pd.DataFrame(results, columns=columns)
df.to_csv(OUTPUT_CSV, index=False)

print(f"Saved embeddings to {OUTPUT_CSV}")