# Medical Symptom Extraction with BioGPT - Single Symptom Version

Simple symptom extraction using Microsoft's BioGPT model. Takes patient messages and identifies ONE primary symptom.

## 1. Import Libraries

In [2]:
import torch
from transformers import BioGptTokenizer, BioGptForCausalLM
import re
import json
from typing import List, Dict, Tuple
import warnings
warnings.filterwarnings('ignore')

# Pick the compute device: the GPU when PyTorch reports CUDA, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cpu


## 2. Load BioGPT Model

In [3]:
# Load the BioGPT tokenizer and causal language model from the
# "microsoft/biogpt" checkpoint (presumably downloaded on first use and then
# read from a local cache -- confirm against the transformers docs).
print("Loading BioGPT model...")
tokenizer = BioGptTokenizer.from_pretrained("microsoft/biogpt")
model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")
model.to(device)  # place the model on the device selected earlier in the notebook
model.eval()  # inference only: this notebook generates text and never trains
print("Model loaded successfully!")

Loading BioGPT model...
Model loaded successfully!


## 3. Symptom Extraction Function

In [4]:
def extract_symptoms(patient_message: str,
                     num_beams: int = 3,
                     max_new_tokens: int = 50) -> Dict:
    """
    Extract ONE primary symptom from patient message using BioGPT.

    The message is embedded in a clinical-style prompt, BioGPT continues the
    prompt, and the continuation is reduced to a single symptom phrase.
    Generation uses sampling (do_sample=True), so repeated calls with the
    same message can return different results.

    Args:
        patient_message: The patient's description of their condition
        num_beams: Number of beams for beam search
        max_new_tokens: Maximum number of new tokens to generate

    Returns:
        Dictionary with three keys:
            "original_message": the input message, unchanged
            "biogpt_output": the text BioGPT generated after the prompt
            "extracted_symptom": dict with "symptom" and "confidence"
    """

    # BioGPT works better with medical/clinical style prompts
    prompt = f"The patient presents with {patient_message}. The primary symptom is"

    # Tokenize input and move the tensors to the model's device
    inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Number of prompt tokens; everything after this index in the generated
    # sequence is new text produced by the model.
    prompt_token_count = inputs["input_ids"].shape[1]

    # Generate response with BioGPT
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,  # Use max_new_tokens instead of max_length
            num_beams=num_beams,
            temperature=0.8,
            do_sample=True,  # Enable sampling for better diversity
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            output_scores=True,
            return_dict_in_generate=True
        )

    # Decode only the newly generated tokens. Locating the prompt inside the
    # decoded text is fragile: the tokenizer round-trip may alter spacing or
    # punctuation, so the prompt string may not be found and a len(prompt)
    # character slice would then cut at the wrong place.
    new_token_ids = outputs.sequences[0][prompt_token_count:]
    generated_part = tokenizer.decode(new_token_ids, skip_special_tokens=True).strip()

    # Parse ONLY THE FIRST symptom from generated text
    primary_symptom = parse_primary_symptom(generated_part)

    # Calculate confidence score for the single symptom
    symptom_with_confidence = calculate_confidence_score(primary_symptom, outputs)

    return {
        "original_message": patient_message,
        "biogpt_output": generated_part,
        "extracted_symptom": symptom_with_confidence  # single symptom dict
    }

In [5]:
def parse_primary_symptom(generated_text: str) -> str:
    """
    Parse ONLY THE FIRST/PRIMARY symptom from BioGPT generated text.

    The text is cut at whichever list/clause delimiter occurs EARLIEST, so
    "nausea and vomiting, fever" yields "nausea" rather than
    "nausea and vomiting".

    Args:
        generated_text: Text generated by BioGPT

    Returns:
        Single primary symptom string, lower-cased; "" when the input is empty
    """

    if not generated_text:
        return ""

    # Trim surrounding whitespace and the sentence-final period(s)
    text = generated_text.strip().rstrip('.')

    # Delimiters that separate the first symptom from whatever follows it
    delimiters = (',', ';', ' and ', ' with ', ' including', ' such as')

    # Cut at the earliest delimiter in the text; keep everything if none occurs
    hits = [pos for pos in (text.find(d) for d in delimiters) if pos != -1]
    cut_at = min(hits, default=len(text))

    # Clean up the symptom
    primary_symptom = text[:cut_at].strip().rstrip('.,;:')

    # If it's a numbered list item ("1. headache"), drop the leading number
    numbered_match = re.match(r'^\d+\.\s*(.+)', primary_symptom)
    if numbered_match:
        primary_symptom = numbered_match.group(1).strip()

    # Keep symptoms concise: phrases longer than five words are shortened to
    # their first four words (thresholds kept as in the original notebook)
    words = primary_symptom.split()
    if len(words) > 5:
        primary_symptom = ' '.join(words[:4])

    return primary_symptom.lower()

In [6]:
def calculate_confidence_score(symptom: str, outputs) -> Dict:
    """
    Attach a confidence value to the single extracted symptom.

    The base value is exp(outputs.sequences_scores[0]) when the generation
    output carries sequence scores, and 0.75 otherwise. A symptom that is
    empty or at most two characters long gets 0.0; a symptom of at most
    three words has its value multiplied by 1.1 and capped at 0.95.

    Args:
        symptom: The extracted symptom
        outputs: Model generation outputs

    Returns:
        Dictionary with "symptom" ("unspecified" when empty) and
        "confidence" rounded to three decimals
    """

    beam_scores = getattr(outputs, 'sequences_scores', None)
    if beam_scores is None:
        # No sequence score available: fall back to the default
        score = 0.75
    else:
        # Sequence score is a log probability; convert it to a probability
        score = torch.exp(beam_scores[0]).item()

    has_usable_symptom = bool(symptom) and len(symptom) > 2
    if not has_usable_symptom:
        score = 0.0
    elif len(symptom.split()) <= 3:
        # Short, clear symptom phrases get a capped boost
        score = min(0.95, score * 1.1)

    return {
        "symptom": symptom or "unspecified",
        "confidence": round(score, 3)
    }

In [10]:
# Quick single-message check; generation samples (do_sample=True), so the output can differ between runs
extract_symptoms("My ankle feels swollen")

{'original_message': 'My ankle feels swollen',
 'biogpt_output': 'pain.',
 'extracted_symptom': {'symptom': 'pain', 'confidence': 0.304}}

## 4. Test the Single Symptom Extraction

In [None]:
# Evaluation set: each entry pairs a free-text patient message ("input")
# with the symptom label the extractor should produce ("expected").
test_cases = [
    {"input": message, "expected": label}
    for message, label in (
        ("I have been feeling really tired recently", "fatigue"),
        ("My head is pounding and throbbing with pain", "headache"),
        ("I feel like I'm going to throw up", "nausea"),
        ("My ankle feels swollen", "ankle swelling"),
        ("i have been coughing since last night", "cough"),
        ("My stomach hurts really bad", "abdominal pain"),
        ("I feel dizzy and the room is spinning", "dizziness"),
        ("I have red bumps all over my skin", "rash"),
        ("I feel extremely tired and have no energy", "fatigue"),
        ("My throat is really sore when I swallow", "sore throat"),
        ("My chest feels tight and it's hard to breathe", "shortness of breath"),
        ("I've been having trouble sleeping", "insomnia"),
        ("My joints are swollen and painful", "joint pain"),
        ("i feel so weak", "weakness"),
        ("I've been throwing up multiple times today", "vomiting"),
        ("My muscles ache all over", "myalgia"),
        ("I can't sleep and stay awake all night", "insomnia"),
        ("i have lost 20kgs in the past few months", "recent weight loss"),
        ("I have sharp pain in my lower back", "back pain"),
        ("my skin is dry and peeling from the weather", "skin peeling"),
    )
]

In [12]:
def check_match(extracted: str, expected: str) -> Tuple[bool, str]:
    """
    Check if extracted symptom matches expected, allowing for variations.

    Comparison is case-insensitive and ignores surrounding whitespace. A pair
    matches when the strings are equal, when one is a substring of the other,
    or when both belong to the same synonym group.

    Args:
        extracted: Symptom produced by the extractor ("unspecified", empty or
            whitespace-only means nothing was extracted)
        expected: Reference symptom label

    Returns:
        Tuple of (is_match, reason)
    """
    if not extracted or extracted == "unspecified":
        return False, f"No symptom extracted. Expected: '{expected}'"

    extracted_lower = extracted.lower().strip()
    expected_lower = expected.lower().strip()

    # A whitespace-only extraction normalises to "", which is a substring of
    # every string and would otherwise be reported as a partial match.
    if not extracted_lower or extracted_lower == "unspecified":
        return False, f"No symptom extracted. Expected: '{expected}'"

    # Exact match
    if extracted_lower == expected_lower:
        return True, f"Exact match: '{extracted}'"

    # Partial match (one contains the other)
    if expected_lower in extracted_lower or extracted_lower in expected_lower:
        return True, f"Partial match: '{extracted}' contains/contained in '{expected}'"

    # Check for common synonyms
    synonym_map = {
        "diarrhea": ["watery diarrhea", "loose stools", "watery stools"],
        "nausea": ["vomiting", "emesis", "feeling sick"],
        "vomiting": ["regurgitation", "throwing up", "emesis"],
        "rhinorrhea": ["runny nose", "nasal discharge"],
        "myalgia": ["muscle pain", "muscle aches"],
        "shortness of breath": ["dyspnea", "breathing difficulty", "respiratory distress"]
    }

    for key, synonyms in synonym_map.items():
        # A group is the key together with its synonyms; both terms must be in it
        group = {key, *synonyms}
        if expected_lower in group and extracted_lower in group:
            return True, f"Synonym match: '{extracted}' ≈ '{expected}'"

    return False, f"No match. Expected: '{expected}', Got: '{extracted}'"

In [13]:
def run_tests():
    """
    Run all test cases and display results.

    Each entry of the module-level test_cases list is passed through
    extract_symptoms, the extracted symptom is compared with the expected
    label by check_match, and a per-case report is printed.

    Returns:
        Tuple of (results, correct_count): results is a list with one dict
        per test case (keys: test_case, input, expected, biogpt_output,
        extracted, confidence, is_correct, reason); correct_count is the
        number of cases that matched.
    """
    print("="*80)
    print("TESTING SINGLE SYMPTOM EXTRACTION")
    print("="*80)
    print(f"\nRunning {len(test_cases)} test cases...\n")

    results = []
    correct_count = 0

    for i, test_case in enumerate(test_cases, 1):
        print(f"\nTest Case {i}/{len(test_cases)}")
        print("-" * 40)
        print(f"Input: \"{test_case['input']}\"")
        print(f"Expected: {test_case['expected']}")

        # Extract symptom
        result = extract_symptoms(test_case['input'])

        # Get the single extracted symptom
        extracted = result['extracted_symptom']['symptom']
        confidence = result['extracted_symptom']['confidence']

        print(f"BioGPT Output: \"{result['biogpt_output']}\"")
        print(f"Extracted Symptom: {extracted}")
        print(f"Confidence: {confidence:.3f}")

        # Check if correct
        is_correct, reason = check_match(extracted, test_case['expected'])

        # Pass/fail markers were mis-encoded in the original notebook text
        if is_correct:
            print(f"Result: ✓ PASS - {reason}")
            correct_count += 1
        else:
            print(f"Result: ✗ FAIL - {reason}")

        results.append({
            "test_case": i,
            "input": test_case['input'],
            "expected": test_case['expected'],
            "biogpt_output": result['biogpt_output'],
            "extracted": extracted,
            "confidence": confidence,
            "is_correct": is_correct,
            "reason": reason
        })

    return results, correct_count

In [14]:
# Run every test case; keep the per-case records and the pass count for the summary cell
results, correct_count = run_tests()

TESTING SINGLE SYMPTOM EXTRACTION

Running 20 test cases...


Test Case 1/20
----------------------------------------
Input: "I have been feeling really tired recently"
Expected: fatigue
BioGPT Output: "a feeling of fullness in the abdomen, which is often accompanied by nausea, vomiting, and abdominal pain."
Extracted Symptom: a feeling of fullness
Confidence: 0.230
Result: ✗ FAIL - No match. Expected: 'fatigue', Got: 'a feeling of fullness'

Test Case 2/20
----------------------------------------
Input: "My head is pounding and throbbing with pain"
Expected: headache
BioGPT Output: "pain in My head."
Extracted Symptom: pain in my head
Confidence: 0.258
Result: ✗ FAIL - No match. Expected: 'headache', Got: 'pain in my head'

Test Case 3/20
----------------------------------------
Input: "I feel like I'm going to throw up"
Expected: nausea
BioGPT Output: "pain in the neck."
Extracted Symptom: pain in the neck
Confidence: 0.170
Result: ✗ FAIL - No match. Expected: 'nausea', Got: 'pain in the neck'

## 5. Performance Summary

In [16]:
def display_summary(results, correct_count):
    """
    Display performance summary.

    Prints totals, accuracy, a qualitative grade and the list of failed
    cases. The total is taken from the results argument itself, so the
    function reports on whatever run it is given; an empty run yields an
    accuracy of 0.0 instead of raising ZeroDivisionError.

    Args:
        results: List of per-case dicts; each needs "is_correct", and failed
            cases also need "test_case", "input", "expected" and "extracted"
        correct_count: Number of cases that matched

    Returns:
        Accuracy as a percentage (0-100)
    """
    print("\n" + "="*80)
    print("PERFORMANCE SUMMARY")
    print("="*80)

    total_tests = len(results)
    accuracy = (correct_count / total_tests) * 100 if total_tests else 0.0

    print(f"\nTotal Test Cases: {total_tests}")
    print(f"Correct Predictions: {correct_count}")
    print(f"Incorrect Predictions: {total_tests - correct_count}")
    print(f"\nAccuracy: {accuracy:.1f}%")

    # Performance grade: first threshold (highest first) the accuracy reaches
    print("\nPerformance Grade:")
    grade_floors = (
        (90, "Excellent"),
        (75, "Good"),
        (60, "Acceptable"),
        (50, "Needs Improvement"),
    )
    grade = next((label for floor, label in grade_floors if accuracy >= floor), "Poor")

    print(f" {grade} ({accuracy:.1f}%)")

    # Show failed cases
    failed_cases = [r for r in results if not r['is_correct']]
    if failed_cases:
        print(f"\nFailed Cases ({len(failed_cases)}):")
        print("-" * 40)
        for case in failed_cases:
            print(f"\nCase {case['test_case']}:")
            print(f"  Input: \"{case['input']}\"")
            print(f"  Expected: '{case['expected']}'")
            print(f"  Got: '{case['extracted']}'")

    return accuracy

# Print the summary for the test run and keep the accuracy percentage it returns
accuracy = display_summary(results, correct_count)


PERFORMANCE SUMMARY

Total Test Cases: 20
Correct Predictions: 5
Incorrect Predictions: 15

Accuracy: 25.0%

Performance Grade:
 Poor (25.0%)

Failed Cases (15):
----------------------------------------

Case 1:
  Input: "I have been feeling really tired recently"
  Expected: 'fatigue'
  Got: 'a feeling of fullness'

Case 2:
  Input: "My head is pounding and throbbing with pain"
  Expected: 'headache'
  Got: 'pain in my head'

Case 3:
  Input: "I feel like I'm going to throw up"
  Expected: 'nausea'
  Got: 'pain in the neck'

Case 4:
  Input: "My ankle feels swollen"
  Expected: 'ankle swelling'
  Got: 'pain'

Case 6:
  Input: "My stomach hurts really bad"
  Expected: 'abdominal pain'
  Got: 'pain in the epigastrium'

Case 8:
  Input: "I have red bumps all over my skin"
  Expected: 'rash'
  Got: 'pruritus'

Case 9:
  Input: "I feel extremely tired and have no energy"
  Expected: 'fatigue'
  Got: 'that the patient is'

Case 10:
  Input: "My throat is really sore when I swallow"
  Expec

## 6. Interactive Testing

In [None]:
# def interactive_test():
#     """
#     Interactive testing - enter your own symptoms.
#     """
#     print("\n" + "="*80)
#     print("INTERACTIVE SINGLE SYMPTOM EXTRACTION")
#     print("="*80)
#     print("Enter patient descriptions to extract the primary symptom.")
#     print("Type 'quit' to exit.\n")
    
#     while True:
#         user_input = input("\nEnter symptom description: ")
        
#         if user_input.lower() in ['quit', 'exit', 'q']:
#             print("Goodbye!")
#             break
        
#         if not user_input.strip():
#             continue
        
#         # Extract symptom
#         result = extract_symptoms(user_input)
        
#         print("\nResults:")
#         print("-" * 40)
#         print(f"BioGPT Output: \"{result['biogpt_output']}\"")
#         print(f"Extracted Symptom: {result['extracted_symptom']['symptom']}")
#         print(f"Confidence: {result['extracted_symptom']['confidence']:.3f}")

# # Uncomment to run interactive mode
# # interactive_test()

## 7. Save Results

In [None]:
# # Save detailed results to JSON file
# import json
# from datetime import datetime

# # Prepare results for saving
# save_data = {
#     "test_date": datetime.now().isoformat(),
#     "model": "microsoft/biogpt",
#     "mode": "single_symptom_extraction",
#     "accuracy": accuracy,
#     "correct_count": correct_count,
#     "total_tests": len(test_cases),
#     "detailed_results": results
# }

# # Save to file
# with open('single_symptom_test_results.json', 'w') as f:
#     json.dump(save_data, f, indent=2)

# print(f"\n📁 Detailed results saved to 'single_symptom_test_results.json'")
# print(f"\n🎯 Final Accuracy: {accuracy:.1f}%")