<a href="https://colab.research.google.com/github/Shamshad-Gilani/BDS/blob/main/CDS_using_sapbert_and_Flan_T5_base.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Cell 0: Install dependencies
!pip install --upgrade pip
!pip install numpy==1.26.4
!pip install faiss-cpu==1.8.0
!pip install transformers==4.42.4 datasets==2.20.0 sentence-transformers==3.0.1 langchain==0.2.7 langchain-community==0.2.7 duckduckgo-search==7.5.5 peft==0.11.1

Collecting transformers==4.42.4
  Downloading transformers-4.42.4-py3-none-any.whl.metadata (43 kB)
Collecting datasets==2.20.0
  Downloading datasets-2.20.0-py3-none-any.whl.metadata (19 kB)
Collecting sentence-transformers==3.0.1
  Downloading sentence_transformers-3.0.1-py3-none-any.whl.metadata (10 kB)
Collecting langchain==0.2.7
  Downloading langchain-0.2.7-py3-none-any.whl.metadata (6.9 kB)
Collecting langchain-community==0.2.7
  Downloading langchain_community-0.2.7-py3-none-any.whl.metadata (2.5 kB)
Collecting duckduckgo-search==7.5.5
  Downloading duckduckgo_search-7.5.5-py3-none-any.whl.metadata (17 kB)
Collecting peft==0.11.1
  Downloading peft-0.11.1-py3-none-any.whl.metadata (13 kB)
Collecting tokenizers<0.20,>=0.19 (from transformers==4.42.4)
  Downloading tokenizers-0.19.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting pyarrow-hotfix (from datasets==2.20.0)
  Downloading pyarrow_hotfix-0.7-py3-none-any.whl.metadata (3.6 kB)
Colle

In [2]:
import pandas as pd
import os
import logging
from google.colab import files
from huggingface_hub import login, whoami
from google.colab import userdata
from transformers import BertForMaskedLM, BertTokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain_community.tools import DuckDuckGoSearchRun
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

In [3]:
# Configuration
config = {
    "model_name": "cambridgeltl/SapBERT-from-PubMedBERT-fulltext",
    "fine_tuned_path": "fine_tuned_sapbert_mlm",
    "sentence_transformer_path": "fine_tuned_sapbert_mlm_sentence",
    "vector_store_path": "faiss_index_mlm",
    "csv_path": "/content/icd11_data.csv",
    "input_file_path": "/content/icd_tab.txt",
    "num_epochs": 3,
    "batch_size": 16,
    "n_clusters": 50,
    "top_k": 5
}

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [4]:
# Cell 1: Authenticate with Hugging Face
logger.info("Authenticating with Hugging Face...")
try:
    hf_token = userdata.get('HF_TOKEN')
    login(token=hf_token)
    logger.info(f"User info: {whoami(token=hf_token)}")
except Exception as e:
    logger.error(f"Error authenticating with Hugging Face: {e}")
    raise

In [5]:
# Cell 2: Load and process ICD-11 data
logger.info("Processing ICD-11 data...")
expected_columns = ['code', 'title', 'prompt']
if os.path.exists(config["csv_path"]):
    pipe_df = pd.read_csv(config["csv_path"], sep='|')
    logger.info("\nLoaded existing icd11_data.csv:")
    logger.info(pipe_df.head().to_string())
    pipe_df.drop_duplicates(inplace=True)
    pipe_df.drop(columns=['Sno', 'definition', 'inclusions', 'exclusions', 'content', 'name'], inplace=True, errors='ignore')
else:
    logger.info("Please upload your ICD-11 tab-delimited text file (expected name: icd_tab.txt).")
    try:
        uploaded = files.upload()
        uploaded_file = None
        for fname in uploaded.keys():
            if fname.lower() == 'icd_tab.txt':
                uploaded_file = fname
                break
        if not uploaded_file:
            raise FileNotFoundError("No file named 'icd_tab.txt' was uploaded.")
        with open(config["input_file_path"], 'wb') as f:
            f.write(uploaded[uploaded_file])
        df = pd.read_csv(config["input_file_path"], sep="\t")
        logger.info("Initial DataFrame head:")
        logger.info(df.head().to_string())
        if not all(col in df.columns for col in expected_columns):
            raise ValueError(f"Input file must contain {expected_columns}. Found: {list(df.columns)}")
        df = df.reset_index(drop=True)
        df.to_csv(config["csv_path"], sep='|', index=False)
        pipe_df = pd.read_csv(config["csv_path"], sep='|')
        pipe_df.drop_duplicates(inplace=True)
        pipe_df.drop(columns=['Sno', 'definition', 'inclusions', 'exclusions', 'content', 'name'], inplace=True, errors='ignore')
    except Exception as e:
        logger.error(f"Error uploading or processing file: {e}")
        raise
if 'prompt' not in pipe_df.columns or pipe_df['prompt'].isna().all():
    raise ValueError("The 'prompt' column is required and must contain non-empty values.")

Saving icd_tab.txt to icd_tab.txt


In [6]:
# prompt: want to view data in pipe.df

# Display the head of the pipe_df DataFrame
print("Head of pipe_df DataFrame:")
print(pipe_df.head().to_string())

# Display the info of the pipe_df DataFrame
print("\nInfo of pipe_df DataFrame:")
pipe_df.info()

# Display the descriptive statistics of the pipe_df DataFrame
print("\nDescription of pipe_df DataFrame:")
print(pipe_df.describe(include='all').to_string())

# Display the columns of the pipe_df DataFrame
print("\nColumns of pipe_df DataFrame:")
print(pipe_df.columns.tolist())

# Display the number of rows and columns
print(f"\nShape of pipe_df DataFrame: {pipe_df.shape}")

# Display the number of unique values in each column
print("\nNumber of unique values in each column:")
print(pipe_df.nunique())

# Display the value counts for the 'code' column (first 20)
if 'code' in pipe_df.columns:
    print("\nValue counts for 'code' column (top 20):")
    print(pipe_df['code'].value_counts().head(20))

# Display the value counts for the 'prompt' column (first 20)
if 'prompt' in pipe_df.columns:
    print("\nValue counts for 'prompt' column (top 20):")
    print(pipe_df['prompt'].value_counts().head(20))

Head of pipe_df DataFrame:
     code                                              title                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                 

In [7]:
# Cell 3: Fine-tune SapBERT with MLM
logger.info("Fine-tuning SapBERT with MLM...")
try:
    model = BertForMaskedLM.from_pretrained(config["model_name"])
    tokenizer = BertTokenizer.from_pretrained(config["model_name"])
    texts = pipe_df['prompt'].fillna('').astype(str).tolist()
    valid_texts = [text for text in texts if text.strip()]
    if not valid_texts:
        raise ValueError("No valid prompts found for fine-tuning.")
    encodings = tokenizer(valid_texts, truncation=True, padding=True, max_length=128)
    dataset = Dataset.from_dict(encodings)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=config["num_epochs"],
        per_device_train_batch_size=config["batch_size"],
        learning_rate=2e-5,
        warmup_steps=500,
        save_steps=500,
        save_total_limit=2,
        report_to="none"
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=data_collator
    )
    trainer.train()
    model.save_pretrained(config["fine_tuned_path"])
    tokenizer.save_pretrained(config["fine_tuned_path"])
    model = SentenceTransformer(config["fine_tuned_path"])
    model.save(config["sentence_transformer_path"])
except Exception as e:
    logger.error(f"Error in fine-tuning SapBERT: {e}")
    raise

config.json:   0%|          | 0.00/462 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

Some weights of BertForMaskedLM were not initialized from the model checkpoint at cambridgeltl/SapBERT-from-PubMedBERT-fulltext and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tokenizer_config.json:   0%|          | 0.00/198 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Step,Training Loss


Some weights of BertModel were not initialized from the model checkpoint at fine_tuned_sapbert_mlm and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# Cell 4: Create FAISS vector store
logger.info("Creating FAISS vector store...")
try:
    embeddings = HuggingFaceEmbeddings(model_name=config["sentence_transformer_path"])
    texts = pipe_df['prompt'].fillna('').astype(str).tolist()
    metadatas = [{"title": t} for t in pipe_df['title'].fillna('').astype(str)]
    valid_texts = []
    valid_metadatas = []
    for text, metadata in zip(texts, metadatas):
        if text.strip():
            valid_texts.append(text)
            valid_metadatas.append(metadata)
    if not valid_texts:
        raise ValueError("No valid prompts found for embedding.")
    vector_store = FAISS.from_texts(valid_texts, embeddings, metadatas=valid_metadatas)
    vector_store.save_local(config["vector_store_path"])
    logger.info("FAISS index created and saved successfully.")
except Exception as e:
    logger.error(f"Failed to create FAISS vector store: {e}")
    raise

  embeddings = HuggingFaceEmbeddings(model_name=config["sentence_transformer_path"])


In [None]:
!pip install -U langchain-community



In [9]:
# Cell 5: Set up LangChain pipeline
from transformers import T5ForConditionalGeneration, T5Tokenizer, pipeline
import torch  # Added import for torch

logger.info("Setting up LangChain pipeline...")
try:
    # Load the model and tokenizer
    model_id = "google/flan-t5-base"
    tokenizer = T5Tokenizer.from_pretrained(model_id)
    model = T5ForConditionalGeneration.from_pretrained(model_id)

    # Create a pipeline for text generation
    text_generation_pipeline = pipeline(
        task="text2text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=200,  # Equivalent to max_new_tokens
        temperature=0.7,
        top_p=0.9,
        device=0 if torch.cuda.is_available() else -1  # Use GPU if available, else CPU
    )

    # Wrap the pipeline in HuggingFacePipeline
    llm = HuggingFacePipeline(pipeline=text_generation_pipeline)
except Exception as e:
    logger.error(f"Error loading model: {e}")
    raise

prompt_template = PromptTemplate(
    input_variables=["user_input", "icd_context", "internet_data"],
    template="""
    You are a medical assistant, not a doctor. Suggest possible medical issues based on user input, ICD-11 matches, and internet data. Do not diagnose. Use clear language. Include: "Doctor advice is sought."
    User Input: {user_input}
    ICD-11 Matches: {icd_context}
    Internet Data: {internet_data}
    Response Format:
    - Possible Issues: [List issues.]
    - Disclaimer: Doctor advice is sought.
    """
)
search = DuckDuckGoSearchRun()

tokenizer_config.json: 0.00B [00:00, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


config.json: 0.00B [00:00, ?B/s]

model.safetensors:   0%|          | 0.00/990M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

  llm = HuggingFacePipeline(pipeline=text_generation_pipeline)


In [10]:
# Cell 6: Define query function
def run_query(user_input, vector_store_path=config["vector_store_path"], model_path=config["sentence_transformer_path"]):
    logger.info(f"Running query: {user_input}")
    try:
        embeddings = HuggingFaceEmbeddings(model_name=model_path)
        vector_store = FAISS.load_local(vector_store_path, embeddings, allow_dangerous_deserialization=True)
        icd_matches = vector_store.similarity_search_with_score(user_input, k=config["top_k"])
        icd_context = "\n".join([f"ICD-11 Code: {m.metadata['title']}, Prompt: {m.page_content}, Score: {score}" for m, score in icd_matches])
    except Exception as e:
        logger.error(f"Error in FAISS retrieval: {e}")
        icd_context = "No ICD-11 matches available."
    try:
        internet_data = search.run(f"{user_input} ICD-11 medical conditions", max_results=3)
    except Exception as e:
        logger.error(f"DuckDuckGo search failed: {e}. Using empty internet data.")
        internet_data = "No internet data available."
    try:
        response = llm(prompt_template.format(user_input=user_input, icd_context=icd_context, internet_data=internet_data))
        return response
    except Exception as e:
        logger.error(f"LLM generation failed: {e}")
        return "Error generating response."

In [11]:
# Cell 7: Run example query
user_input = "Patient has anxiety and panic attacks"
response = run_query(user_input)
logger.info("Response:\n" + response)

  ddgs_gen = ddgs.text(
ERROR:__main__:DuckDuckGo search failed: https://html.duckduckgo.com/html 202 Ratelimit. Using empty internet data.
  response = llm(prompt_template.format(user_input=user_input, icd_context=icd_context, internet_data=internet_data))
Token indices sequence length is longer than the specified maximum sequence length for this model (7297 > 512). Running this sequence through the model will result in indexing errors


In [12]:
# Cell 8: Unsupervised testing (clustering)
print("Starting unsupervised testing (clustering)...")
logger.info("Performing unsupervised testing (clustering)...")
try:
    embeddings = HuggingFaceEmbeddings(model_name=config["sentence_transformer_path"])
    logger.info(f"Loading FAISS index from {config['vector_store_path']}")
    print(f"Loading FAISS index from {config['vector_store_path']}")
    vector_store = FAISS.load_local(config["vector_store_path"], embeddings, allow_dangerous_deserialization=True)
    logger.info(f"FAISS index loaded, total embeddings: {vector_store.index.ntotal}")
    print(f"FAISS index loaded, total embeddings: {vector_store.index.ntotal}")
    if vector_store.index.ntotal < config["n_clusters"]:
        error_msg = f"Too few embeddings ({vector_store.index.ntotal}) for {config['n_clusters']} clusters"
        logger.error(error_msg)
        print(error_msg)
        raise ValueError(error_msg)
    embeddings_array = vector_store.index.reconstruct_n(0, vector_store.index.ntotal)
    logger.info(f"Embeddings array shape: {embeddings_array.shape}")
    print(f"Embeddings array shape: {embeddings_array.shape}")
    kmeans = KMeans(n_clusters=config["n_clusters"], random_state=42)
    labels = kmeans.fit_predict(embeddings_array)
    score = silhouette_score(embeddings_array, labels)
    logger.info(f"MLM - Silhouette Score: {score:.4f}")
    print(f"MLM - Silhouette Score: {score:.4f}")
    with open("/content/clustering_output.txt", "w") as f:
        f.write(f"MLM - Silhouette Score: {score:.4f}\n")
    logger.info("Clustering output saved to /content/clustering_output.txt")
    print("Clustering output saved to /content/clustering_output.txt")
except Exception as e:
    logger.error(f"Clustering failed: {str(e)}")
    print(f"Clustering failed: {str(e)}")
    raise

Starting unsupervised testing (clustering)...
Loading FAISS index from faiss_index_mlm
FAISS index loaded, total embeddings: 423
Embeddings array shape: (423, 768)
MLM - Silhouette Score: 0.1172
Clustering output saved to /content/clustering_output.txt


In [13]:
# Cell 9: Semi-supervised testing (Top-5 accuracy)
print("Starting semi-supervised testing...")
logger.info("Performing semi-supervised testing...")
try:
    test_queries = []
    for _, row in pipe_df.sample(min(50, len(pipe_df))).iterrows():
        prompt = row['prompt'].lower()
        query = f"Patient has {prompt}"  # Simplified to avoid colon parsing issues
        test_queries.append({"input": query, "expected_code": row['title']})
    logger.info(f"Generated {len(test_queries)} test queries")
    print(f"Generated {len(test_queries)} test queries")
    if not test_queries:
        error_msg = "No test queries generated"
        logger.error(error_msg)
        print(error_msg)
        raise ValueError(error_msg)

    def evaluate_model(model_path, vector_store_path, test_inputs, top_k=config["top_k"]):
        logger.info(f"Evaluating model with top-{top_k} accuracy...")
        print(f"Evaluating model with top-{top_k} accuracy...")
        embeddings = HuggingFaceEmbeddings(model_name=model_path)
        logger.info(f"Loading FAISS index from {vector_store_path}")
        print(f"Loading FAISS index from {vector_store_path}")
        vector_store = FAISS.load_local(vector_store_path, embeddings, allow_dangerous_deserialization=True)
        logger.info(f"FAISS index loaded, total embeddings: {vector_store.index.ntotal}")
        print(f"FAISS index loaded, total embeddings: {vector_store.index.ntotal}")
        results = []
        for test in test_inputs:
            icd_matches = vector_store.similarity_search_with_score(test['input'], k=top_k)
            predicted_codes = [m.metadata['title'] for m, score in icd_matches]
            is_correct = test['expected_code'] in predicted_codes
            results.append(is_correct)
            logger.debug(f"Query: {test['input']}, Correct: {is_correct}, Predicted: {predicted_codes}")
        accuracy = sum(results) / len(results)
        logger.info(f"MLM - Top-{top_k} Accuracy: {accuracy:.4f}")
        print(f"MLM - Top-{top_k} Accuracy: {accuracy:.4f}")
        with open("/content/evaluation_output.txt", "w") as f:
            f.write(f"MLM - Top-{top_k} Accuracy: {accuracy:.4f}\n")
        logger.info("Evaluation output saved to /content/evaluation_output.txt")
        print("Evaluation output saved to /content/evaluation_output.txt")
        return accuracy

    accuracy = evaluate_model(config["sentence_transformer_path"], config["vector_store_path"], test_queries)
except Exception as e:
    logger.error(f"Evaluation failed: {str(e)}")
    print(f"Evaluation failed: {str(e)}")
    raise

Starting semi-supervised testing...
Generated 50 test queries
Evaluating model with top-5 accuracy...
Loading FAISS index from faiss_index_mlm
FAISS index loaded, total embeddings: 423
MLM - Top-5 Accuracy: 1.0000
Evaluation output saved to /content/evaluation_output.txt


In [14]:
# prompt: want to view data in clustering output and evaluation output files

!cat /content/clustering_output.txt
!cat /content/evaluation_output.txt

MLM - Silhouette Score: 0.1172
MLM - Top-5 Accuracy: 1.0000
