<a href="https://colab.research.google.com/github/SarahPendhari/CardioCare-heart-llm-qna/blob/main/LLMfinal.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Set CUDA environment variable for better error reporting.
# This assignment runs before `torch` is imported further down in this cell.
# NOTE(review): presumably a debugging aid that makes CUDA kernel launches
# synchronous — confirm whether it should stay enabled for full training runs.
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Install required packages.
# The leading `!` is notebook shell syntax, so this cell is not valid as a plain
# Python script; `-q` keeps pip's output quiet.
!pip install transformers datasets accelerate -q

import torch
import re
from datasets import load_dataset, concatenate_datasets
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
    DataCollatorForSeq2Seq
)

# Load the dataset and report how many rows its "train" and "test" splits hold.
print("Loading dataset...")
dataset = load_dataset("GBaker/MedQA-USMLE-4-options")
print(f"Original dataset size: Train={len(dataset['train'])}, Test={len(dataset['test'])}")

# Function to identify heart/cardiovascular related questions

# Keywords (all lowercase) that mark a question as cardiology related.
# NOTE(review): generic vital-sign terms such as 'pulse', 'blood pressure' and
# 'edema' also occur in non-cardiac vignettes; the recorded run kept 5306 of
# 10178 training questions and its first sample is a urinary-symptom case.
# Consider pruning these terms if a stricter cardiology subset is wanted.
_HEART_KEYWORDS = (
    'heart', 'cardiac', 'cardio', 'cardiovascular', 'myocardial', 'coronary',
    'arrhythmia', 'atrial', 'ventricular', 'aorta', 'aortic', 'angina',
    'cardiomyopathy', 'pericardial', 'pericardium', 'endocarditis',
    'atherosclerosis', 'ecg', 'electrocardiogram', 'ekg', 'tachycardia',
    'bradycardia', 'fibrillation', 'palpitation', 'hypertension', 'hypertensive',
    'thrombosis', 'embolism', 'stroke', 'ischemia', 'ischemic', 'infarction',
    'stent', 'angioplasty', 'bypass', 'valve', 'mitral', 'tricuspid', 'pulmonary hypertension',
    'congestive heart failure', 'chf', 'murmur', 'pulse', 'blood pressure', 'circulation',
    'vascular', 'vasculitis', 'chest pain', 'dyspnea', 'syncope', 'edema',
)

# A single whole-word alternation, compiled once, replaces compiling and
# searching one pattern per keyword on every call.
_HEART_KEYWORD_RE = re.compile(
    r'\b(?:' + '|'.join(re.escape(keyword) for keyword in _HEART_KEYWORDS) + r')\b'
)


def is_heart_related(example):
    """Return True when a dataset row mentions a cardiology keyword.

    Matching is case-insensitive (the text is lowercased) and whole-word, so
    'cardio' does not match inside 'cardiology'.

    Args:
        example: mapping with a 'question' string and, optionally, a
            'metamap_phrases' list of strings.

    Returns:
        True if any keyword occurs in the question or in the joined
        metamap phrases, otherwise False.
    """
    if _HEART_KEYWORD_RE.search(example['question'].lower()):
        return True

    # 'metamap_phrases' is optional; a missing, None or empty value simply
    # contributes no text to search.
    phrases = example['metamap_phrases'] if 'metamap_phrases' in example else None
    if phrases:
        return _HEART_KEYWORD_RE.search(' '.join(phrases).lower()) is not None

    return False

# Filter the dataset to only include heart-related questions
print("Filtering dataset for heart-related content...")

# Apply the filter: keep only rows for which is_heart_related() returns True
train_heart_filtered = dataset["train"].filter(is_heart_related)
test_heart_filtered = dataset["test"].filter(is_heart_related)

# Create a new dataset dictionary with only heart-related questions.
# This is a plain dict keyed by split name; later steps iterate over its items.
heart_dataset = {"train": train_heart_filtered, "test": test_heart_filtered}
print(f"Heart dataset size: Train={len(heart_dataset['train'])}, Test={len(heart_dataset['test'])}")

# If the heart dataset is too small, augment with some general questions.
# Only the train split is topped up; the test split is left as filtered.
min_training_size = 1000
if len(heart_dataset['train']) < min_training_size:
    num_additional = min_training_size - len(heart_dataset['train'])
    print(f"Heart dataset is small, adding {num_additional} general questions...")
    # Add some general questions to ensure enough training data:
    # re-filter the full train split with the predicate negated ...
    non_heart_dataset = dataset["train"].filter(lambda x: not is_heart_related(x))
    # ... and take the first rows in dataset order (no shuffling), capped at
    # however many non-heart rows exist.
    additional_examples = non_heart_dataset.select(range(min(num_additional, len(non_heart_dataset))))
    # Combine the datasets (heart rows first, then the general rows)
    heart_dataset["train"] = concatenate_datasets([heart_dataset["train"], additional_examples])
    print(f"Augmented dataset size: Train={len(heart_dataset['train'])}")

# Check GPU availability
print("Checking GPU availability...")

# Start from the CPU and only switch to CUDA once it has proven usable.
device = torch.device("cpu")
if not torch.cuda.is_available():
    print("No GPU available, using CPU")
else:
    try:
        # Allocating a tiny tensor on the GPU surfaces runtime errors early.
        test_tensor = torch.tensor([1.0, 2.0, 3.0], device="cuda")
        print("GPU test successful:", test_tensor.device)
        device = torch.device("cuda")
    except RuntimeError as e:
        print(f"GPU error: {e}")
        print("Falling back to CPU")

print(f"Using device: {device}")

# Load T5 model and tokenizer
print("Loading model and tokenizer...")
model_name = "t5-base"  # Using t5-base for better performance
tokenizer = T5Tokenizer.from_pretrained(model_name)

# Load model with proper error handling.
# Any exception raised while loading the weights or while moving them to
# `device` triggers the fallback below.
try:
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    model.to(device)
    print("Model loaded and moved to device successfully")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Trying with CPU only...")
    # The fallback rebinds the global `device` to CPU and repeats the same
    # from_pretrained() call; an error raised here is not caught.
    device = torch.device("cpu")
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    model.to(device)

# Format the dataset for T5 with cardiology prefix
print("Formatting dataset...")
def format_cardio_medqa(example):
    """Convert one MedQA row into a text-to-text training pair.

    The prompt is a fixed cardiology prefix, the question, and every option
    rendered as "<key>: <value>. " in the option mapping's own order. The
    target is the row's 'answer_idx' value (the answer letter).

    Args:
        example: mapping with 'question', 'options' (a mapping of option key
            to option text) and 'answer_idx'.

    Returns:
        dict with the keys "input_text" and "target_text".
    """
    rendered_options = "".join(
        f"{letter}: {text}. " for letter, text in example['options'].items()
    )
    prompt = (
        f"answer cardiology medical question: {example['question']} Options: "
        + rendered_options
    )
    return {"input_text": prompt, "target_text": example['answer_idx']}

# Run the formatter over every split, keeping the same split names
formatted_dataset = {
    split: data.map(format_cardio_medqa)
    for split, data in heart_dataset.items()
}

# Show the first formatted training row so the prompt layout can be inspected
if len(formatted_dataset["train"]) == 0:
    print("\nNo heart-related questions found in the dataset.")
else:
    print("\nSample formatted entry:")
    print(formatted_dataset["train"][0])

# Tokenize the dataset properly for T5
print("Tokenizing dataset...")
def tokenize_function(examples):
    """Tokenize a batch of formatted rows for seq2seq training.

    Args:
        examples: batch mapping holding parallel "input_text" and
            "target_text" lists (the columns produced by format_cardio_medqa).

    Returns:
        The tokenizer output for the inputs (padded/truncated to 512 tokens)
        with an added "labels" entry: the target token ids padded/truncated
        to 8 tokens, where every padding id has been replaced by -100.
    """
    inputs = tokenizer(
        examples["input_text"],
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

    # Tokenize the targets. `text_target=` is the supported replacement for
    # the deprecated `tokenizer.as_target_tokenizer()` context manager.
    labels = tokenizer(
        text_target=examples["target_text"],
        padding="max_length",
        truncation=True,
        max_length=8,
        return_tensors="pt"
    )

    # Replace padding token id's with -100 so they're not included in loss computation
    labels_with_ignore = labels["input_ids"].clone()
    labels_with_ignore[labels_with_ignore == tokenizer.pad_token_id] = -100

    inputs["labels"] = labels_with_ignore
    return inputs

# Apply tokenization to every split, dropping the original text columns so
# only the tokenizer outputs remain.
tokenized_dataset = {}
for split, data in formatted_dataset.items():
    tokenized_dataset[split] = data.map(
        tokenize_function,
        batched=True,
        batch_size=8,
        remove_columns=data.column_names
    )

print("\nTokenized dataset structure:")
for split, data in tokenized_dataset.items():
    print(f"{split}: {data}")

# Check if we actually have enough data to train
if len(tokenized_dataset["train"]) < 100:
    print("Not enough heart-related data to train effectively. Exiting.")
    # `exit()` is the interactive-shell helper installed by the `site` module
    # and is not meant for programs; raising SystemExit directly stops
    # execution here with the same (default) exit status.
    raise SystemExit

# Create data collator with padding (used by the Trainer below to assemble batches)
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    padding=True
)

# Define training arguments with more epochs for specialized knowledge.
# Evaluation and checkpointing both happen once per epoch, and the checkpoint
# with the best "loss" metric is reloaded when training ends.
# NOTE(review): the recorded run shows validation loss staying between roughly
# 0.69 and 0.73 over all 10 epochs — the model does not appear to improve with
# these settings; revisit before relying on the results.
print("Setting up training configuration...")
training_args = TrainingArguments(
    output_dir="./heart_model_results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=10,              # More epochs for specialized knowledge
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    report_to="none",
    # Mixed precision only when CUDA is present and the device's major compute capability is >= 7
    fp16=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 7,
)

# Initialize the Trainer; the filtered "test" split doubles as the evaluation set.
# NOTE(review): the recorded output shows a warning raised at the
# `trainer = Trainer(` line — presumably a deprecation notice for one of these
# keyword arguments; confirm against the installed transformers version.
print("Initializing trainer...")
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Free up memory before training (only when the selected device is CUDA)
if device.type == "cuda":
    torch.cuda.empty_cache()

# Start training with error handling
print("Starting training...")


def generate_answer(question, options):
    """Return the model's decoded answer for one multiple-choice question.

    The prompt uses the same prefix and option layout as the training data.

    Args:
        question: the question text.
        options: mapping of option key to option text.

    Returns:
        The generated text with special tokens removed.
    """
    input_text = f"answer cardiology medical question: {question} Options: "
    input_text += "".join(f"{key}: {value}. " for key, value in options.items())

    # Truncate to 512 tokens like the training tokenization, and place the
    # tensors on the device the model is actually on (the Trainer may have
    # moved it) instead of relying on the global `device`.
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    ).to(model.device)
    model.eval()  # make sure dropout is off while generating
    outputs = model.generate(**inputs, max_length=8)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Only the training call is guarded: the troubleshooting hints below are about
# training, so failures while saving or sampling should not be reported as
# training errors.
try:
    trainer.train()
except Exception as e:
    print(f"Training error: {e}")
    print("Troubleshooting recommendations:")
    print("1. Try using a smaller model like t5-small")
    print("2. Further reduce batch size")
    print("3. Consider training on CPU if GPU memory is limited")
else:
    print("Training completed successfully!")

    # Save the fine-tuned model
    model.save_pretrained("./heart-medqa-t5-model")
    tokenizer.save_pretrained("./heart-medqa-t5-model")

    # Test the model on a sample heart question from the filtered test split
    if len(heart_dataset["test"]) > 0:
        sample = heart_dataset["test"][0]
        print("\nSample heart question:")
        print(sample["question"])
        print("Options:", sample["options"])
        print("Correct answer:", sample["answer_idx"])

        predicted = generate_answer(sample["question"], sample["options"])
        print("Model prediction:", predicted)
    else:
        print("No heart-related questions found in the test set for evaluation.")

Loading dataset...
Original dataset size: Train=10178, Test=1273
Filtering dataset for heart-related content...
Heart dataset size: Train=5306, Test=674
Checking GPU availability...
GPU test successful: cuda:0
Using device: cuda
Loading model and tokenizer...
Model loaded and moved to device successfully
Formatting dataset...

Sample formatted entry:
{'question': 'A 23-year-old pregnant woman at 22 weeks gestation presents with burning upon urination. She states it started 1 day ago and has been worsening despite drinking more water and taking cranberry extract. She otherwise feels well and is followed by a doctor for her pregnancy. Her temperature is 97.7°F (36.5°C), blood pressure is 122/77 mmHg, pulse is 80/min, respirations are 19/min, and oxygen saturation is 98% on room air. Physical exam is notable for an absence of costovertebral angle tenderness and a gravid uterus. Which of the following is the best treatment for this patient?', 'answer': 'Nitrofurantoin', 'options': {'A': 'A

  trainer = Trainer(


Epoch,Training Loss,Validation Loss
1,0.7135,0.725118
2,0.7204,0.700507
3,0.7089,0.68714
4,0.707,0.700796
5,0.7271,0.711328
6,0.7627,0.707684
7,0.7589,0.706322
8,0.7723,0.705032
9,0.7552,0.705018
10,0.7539,0.705054


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


Training completed successfully!

Sample heart question:
Two weeks after undergoing an emergency cardiac catherization with stenting for unstable angina pectoris, a 61-year-old man has decreased urinary output and malaise. He has type 2 diabetes mellitus and osteoarthritis of the hips. Prior to admission, his medications were insulin and naproxen. He was also started on aspirin, clopidogrel, and metoprolol after the coronary intervention. His temperature is 38°C (100.4°F), pulse is 93/min, and blood pressure is 125/85 mm Hg. Examination shows mottled, reticulated purplish discoloration of the feet. Laboratory studies show:
Hemoglobin count 14 g/dL
Leukocyte count 16,400/mm3
Segmented neutrophils 56%
Eosinophils 11%
Lymphocytes 31%
Monocytes 2%
Platelet count 260,000/mm3
Erythrocyte sedimentation rate 68 mm/h
Serum
Urea nitrogen 25 mg/dL
Creatinine 4.2 mg/dL
Renal biopsy shows intravascular spindle-shaped vacuoles. Which of the following is the most likely cause of this patient's sympto

In [None]:
# Write the model and tokenizer files to ./heart-medqa-t5-model — the same
# directory the training cell saves to after a successful run, and the path
# the demo cell loads from.
model.save_pretrained("./heart-medqa-t5-model")
tokenizer.save_pretrained("./heart-medqa-t5-model")

('./heart-medqa-t5-model/tokenizer_config.json',
 './heart-medqa-t5-model/special_tokens_map.json',
 './heart-medqa-t5-model/spiece.model',
 './heart-medqa-t5-model/added_tokens.json')

In [None]:
# Install Gradio, which the demo cell imports (notebook shell syntax; -q keeps pip quiet)
!pip install gradio -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.9/46.9 MB[0m [31m47.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m322.2/322.2 kB[0m [31m27.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m95.2/95.2 kB[0m [31m8.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.4/11.4 MB[0m [31m119.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.0/72.0 kB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.3/62.3 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import gradio as gr
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Load the fine-tuned model and tokenizer from the directory written by the
# save step; this rebinds the notebook-level `tokenizer` and `model` names.
model_path = "./heart-medqa-t5-model"  # Path to your saved model
tokenizer = T5Tokenizer.from_pretrained(model_path)
model = T5ForConditionalGeneration.from_pretrained(model_path)

# Check GPU availability: use CUDA when present, otherwise the CPU, and move
# the model there so inference inputs can be sent to the same device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Additional knowledge base for common cardiology responses.
# Keys are lowercase phrases; a question containing a key (as a substring of
# its lowercased text) is answered with the canned text, first match wins.
cardio_knowledge = {
    "heart attack": "A heart attack (myocardial infarction) occurs when blood flow to part of the heart is blocked, causing damage to heart muscle. Common symptoms include chest pain, shortness of breath, and pain radiating to the arm, jaw or back.",
    "arrhythmia": "Arrhythmias are irregular heartbeats caused by problems with the heart's electrical system. They may feel like palpitations, fluttering, or a racing heartbeat.",
    "hypertension": "Hypertension (high blood pressure) is a condition where the force of blood against artery walls is consistently too high. It's often called a 'silent killer' because it frequently has no symptoms.",
    "murmur": "Heart murmurs are unusual sounds heard during a heartbeat cycle, like whooshing or swishing noises. Some are harmless (innocent), while others indicate heart problems.",
    "ecg": "An electrocardiogram (ECG or EKG) records the electrical signals in your heart. It's used to detect heart problems and monitor heart health.",
}

# The model was trained on multiple-choice prompts, so open-ended questions
# are paired with these generic options.
_GENERIC_OPTIONS = {
    "A": "This represents a normal cardiac finding.",
    "B": "This indicates a pathological cardiac condition.",
    "C": "This requires immediate medical intervention.",
    "D": "This is a benign variation requiring monitoring.",
}

# User-facing text for each option letter the model can choose.
_OPTION_RESPONSES = {
    "A": "This appears to be a normal cardiac finding. Regular monitoring is advised.",
    "B": "This likely indicates a pathological cardiac condition. A consultation with a cardiologist is recommended.",
    "C": "This condition may require prompt medical attention. Please consult with a healthcare provider soon.",
    "D": "This appears to be a benign variation, but should be monitored by a healthcare professional.",
}

_UNKNOWN_RESPONSE = "I'm not able to provide a specific answer to this question. Please consult with a healthcare professional."
_MEDICAL_DISCLAIMER = "\n\nNote: This information is provided by an AI assistant and should not replace professional medical advice."


def answer_cardio_question(question):
    """Generate a response to an open-ended cardiology question.

    Args:
        question: the user's question; None or whitespace-only input is
            treated as empty.

    Returns:
        A prompt to enter a question when the input is empty; the canned
        knowledge-base text when the question contains a known phrase;
        otherwise the response mapped from the option letter the model
        generates, followed by a medical disclaimer. If inference raises, a
        generic apology string is returned instead.
    """
    # Check if question is empty (None is handled like an empty string)
    if not question or not question.strip():
        return "Please enter a cardiology-related question."

    # Check if the question matches our knowledge base first
    lowered = question.lower()
    for key, response in cardio_knowledge.items():
        if key in lowered:
            return response

    # Format input as the model expects, but with our generic options
    input_text = f"answer cardiology medical question: {question} Options: "
    input_text += "".join(f"{key}: {value}. " for key, value in _GENERIC_OPTIONS.items())

    # Generate the answer. This is the UI boundary, so failures are turned
    # into a friendly message — but logged rather than silently dropped.
    try:
        inputs = tokenizer(
            input_text,
            return_tensors="pt",
            truncation=True,  # keep long questions within the 512-token limit used in training
            max_length=512,
        ).to(device)
        outputs = model.generate(**inputs, max_length=8)
        answer_idx = tokenizer.decode(outputs[0], skip_special_tokens=True)
    except Exception:
        import logging

        logging.getLogger(__name__).exception("Cardiology model inference failed")
        return "I encountered an error while processing your question. Please try rephrasing it or consult with a healthcare professional."

    # Normalize the decoded text so stray whitespace or a lowercase letter
    # still selects the intended option.
    response = _OPTION_RESPONSES.get(answer_idx.strip().upper(), _UNKNOWN_RESPONSE)

    # Add medical disclaimer
    return response + _MEDICAL_DISCLAIMER

# Create the chat interface with a simpler approach
with gr.Blocks(title="Cardiology Medical Chatbot") as demo:
    gr.Markdown("# Cardiology Medical Assistant")
    gr.Markdown("""
    This chatbot can answer questions about heart health and cardiology.

    **Example questions you can ask:**
    - What are the symptoms of a heart attack?
    - What is atrial fibrillation?
    - How is heart failure treated?
    - What does an elevated troponin level indicate?
    - What are the risk factors for coronary artery disease?

    **Disclaimer:** This AI assistant provides general information only and should not replace professional medical advice.
    """)

    # Create a simple chatbot interface.
    # NOTE(review): the recorded run emitted a warning at this line —
    # presumably about the default message format; confirm against the
    # installed Gradio version before changing the history format below.
    chatbot = gr.Chatbot()

    # Input area with submit button
    with gr.Row():
        input_text = gr.Textbox(
            placeholder="Enter your cardiology question...",
            lines=2,
            scale=4
        )
        submit_button = gr.Button("Submit", scale=1)

    # Clear chat button
    clear_button = gr.Button("Clear Chat")

    # Define the chat function
    def chat(user_message, history):
        """Answer one message and return the cleared textbox plus new history.

        Args:
            user_message: text taken from the input box.
            history: existing list of (user, bot) pairs; None is treated as
                an empty conversation.

        Returns:
            A tuple of "" (to clear the textbox) and a new history list with
            the latest (user_message, bot_response) pair appended.
        """
        bot_response = answer_cardio_question(user_message)
        # Build a fresh list instead of mutating the object passed in.
        updated_history = list(history or [])
        updated_history.append((user_message, bot_response))
        return "", updated_history

    # Connect components: the button click and pressing Enter in the textbox
    # both run the same handler with the same inputs and outputs.
    for register_event in (submit_button.click, input_text.submit):
        register_event(
            chat,
            inputs=[input_text, chatbot],
            outputs=[input_text, chatbot]
        )

    # Clear chat history
    def clear_history():
        """Return an empty history so the chatbot display is reset."""
        return []

    clear_button.click(
        clear_history,
        outputs=[chatbot]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()

  chatbot = gr.Chatbot()


Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://86bc41644d06ee0c6f.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
