# LLM Response Analysis: Cancer Survival Prediction (Improved)

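This notebook evaluates the LLM's cancer-survival predictions using an **improved prediction extraction** step that:
1. Handles negation (e.g., 'not survive' -> 0, 'not die' -> 1)
2. Detects conflicting information (e.g., the response says '1' but the text mentions 'death') and returns no prediction when the conflict cannot be resolved
3. Breaks ties using an explicit 0/1 at the end of the response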

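**Metrics computed** (survival = positive class):
- Confusion Matrix
- Accuracy, Precision, Recall (sensitivity), Specificity, F1-Score
- Classification Report by cancer system (**TODO**: not implemented yet — the report below is overall, per class; `cancer_system` is already stored for each patient)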

In [None]:
import json
import re
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    confusion_matrix, 
    classification_report, 
    accuracy_score,
    precision_score, 
    recall_score, 
    f1_score,
    ConfusionMatrixDisplay
)

# Set plot style. Matplotlib >= 3.6 ships the seaborn sheets as 'seaborn-v0_8-*'
# (the old 'seaborn-whitegrid' name was deprecated then and later removed); older
# releases only know the old name. Use whichever exists instead of failing with OSError.
for _style in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    if _style in plt.style.available:
        plt.style.use(_style)
        break
else:
    sns.set_style('whitegrid')  # neither sheet available: seaborn's own equivalent
sns.set_palette('husl')

## 1. Load Results

In [None]:
# Load the results file (expected: a JSON list of per-patient result dicts).
RESULTS_FILE = 'results_rag.json'  # Update path if needed

# JSON text is UTF-8 by spec; don't depend on the platform's default encoding.
with open(RESULTS_FILE, 'r', encoding='utf-8') as f:
    results = json.load(f)

# Fail early with a clear message instead of an IndexError/KeyError on results[0].
if not isinstance(results, list) or not results:
    raise ValueError(f"{RESULTS_FILE} must contain a non-empty JSON list of results")

print(f"Loaded {len(results)} results")
print(f"\nSample result keys: {list(results[0].keys())}")

## 2. Improved Prediction Extraction

Key improvements:
- **Conflict Detection**: Tracks survival and death signals separately; if both appear, the response is treated as ambiguous (`None`) unless the end-of-text rule resolves it
- **Negation Handling**: 'not survive' -> 0, 'not die' -> 1
- **End-of-text Priority**: Trusts an explicit 0/1 at the end of the response as the tie-breaker

In [None]:
def extract_prediction(response: str) -> int | None:
    """
    Extract the binary prediction (0 or 1) from the LLM response.
    Returns None if:
    1. No prediction is found
    2. Conflicting information is detected (e.g., explicit '1' but text says 'death')
    """
    response = response.strip().lower()
    
    # Track detected signals
    detected_1 = False
    detected_0 = False
    
    # 1. Check for Explicit Match at End (Strongest Signal)
    end_match = re.search(r'\b([01])\s*[.]?(\s*<\|assistant_end\|>)?$', response)
    if end_match:
        val = int(end_match.group(1))
        if val == 1: detected_1 = True
        else: detected_0 = True
    
    # 2. Check for 'Prediction: X' patterns
    patterns = [
        r'(?:prediction|answer|result|outcome|value)[\s:]+([01])\b',
        r'\breturn\s+([01])\b',
    ]
    for pattern in patterns:
        match = re.search(pattern, response)
        if match:
            val = int(match.group(1))
            if val == 1: detected_1 = True
            else: detected_0 = True

    # 3. Check for Textual Keywords
    
    # Negation Handling (CRITICAL)
    if re.search(r'\b(not|unlikely to|fail to|did not) (survive|live|recover)\b', response):
        detected_0 = True
    if re.search(r'\b(not|unlikely to) (die|succumb)\b', response):
        detected_1 = True
        
    # Standard Keywords
    if re.search(r'\b(survive|survival|lives|alive)\b', response):
        detected_1 = True
    if re.search(r'\b(die|death|dead|deceased|succumb)\b', response):
        detected_0 = True
    
    # 4. Resolve Conflict
    if detected_1 and detected_0:
        # CONFLICT! Trust end-of-text match if present
        if end_match:
            return int(end_match.group(1))
        return None  # Genuinely contradictory
        
    if detected_1: return 1
    if detected_0: return 0
    
    return None

# Sanity-check the extraction on hand-written responses. Each case lists the label
# the extractor is meant to return; mismatches are reported (not asserted) so the
# notebook keeps running while the heuristics are tuned.
test_cases = [
    ("1", 1),
    ("0", 0),
    ("The patient will survive.", 1),
    ("The patient will not survive.", 0),             # negated survival
    ("The patient will not die.", 1),                 # negated death
    ("Based on the data, I predict 0 (death).", 0),
    ("The patient is predicted to die.0<|assistant_end|>", 0),  # explicit end label
    ("The patient will survive but may die soon.", None),     # conflict, no end label -> None
    ("The patient will survive but may die soon. 1", 1),      # conflict resolved by end label
]

print("Testing extraction function:")
mismatches = 0
for resp, expected in test_cases:
    got = extract_prediction(resp)
    if got == expected:
        status = "ok"
    else:
        status = f"MISMATCH (expected {expected})"
        mismatches += 1
    shown = resp if len(resp) <= 50 else resp[:50] + '...'
    print(f"  {shown!r} -> {got}  [{status}]")
print(f"\n{mismatches} mismatch(es) out of {len(test_cases)} cases")

In [None]:
# Build one analysis row per result: the parsed prediction plus ground-truth fields.
#
# Ground truth comes as died_from_cancer (1 = died from cancer, 0 = alive or died
# from another cause). The LLM predicts SURVIVAL (1 = survive, 0 = death), so the
# target is inverted: survival = 1 - died_from_cancer. Deaths from other causes
# therefore count as survival (1) here.
def _build_record(item):
    """Flatten one raw result dict into a row for the analysis DataFrame."""
    predicted = extract_prediction(item['response'])
    truth = item['ground_truth']
    cancer_death = int(truth['died_from_cancer'])
    return {
        'patient_id': item.get('patient_id'),
        'response': item['response'],
        'prediction': predicted,
        'ground_truth_survival': 1 - cancer_death,
        'died_from_cancer': cancer_death,
        'is_alive': int(truth['is_alive']),
        'cancer_system': item.get('features', {}).get('cancer_system', 'Unknown'),
        'survival_months': float(truth.get('survival_months', 0)),
    }


data = [_build_record(item) for item in results]
df = pd.DataFrame(data)

missing_mask = df['prediction'].isna()
print(f"Total samples: {len(df)}")
print(f"Samples with valid predictions: {(~missing_mask).sum()}")
print(f"Samples with missing/conflicting predictions: {missing_mask.sum()}")

In [None]:
# List up to five responses the extractor left unresolved (no signal, or
# conflicting signals without an end-of-text tie-breaker).
failed_extractions = df[df['prediction'].isna()]
if failed_extractions.empty:
    print("All predictions extracted successfully!")
else:
    print(f"\nFailed/Conflicting extractions ({len(failed_extractions)}):")
    for snippet in failed_extractions['response'].head(5):
        print(f"\n  Response: {snippet[:200]}...")

In [None]:
# Filter to valid predictions only. Missing predictions mean the column is not an
# integer column, so cast the surviving rows back to int for sklearn.
df_valid = df[df['prediction'].notna()].copy()
df_valid['prediction'] = df_valid['prediction'].astype(int)

# Guard the percentage so an empty results table doesn't raise ZeroDivisionError.
valid_pct = 100 * len(df_valid) / len(df) if len(df) else 0.0
print(f"\nValid predictions: {len(df_valid)} / {len(df)} ({valid_pct:.1f}%)")
print("\nPrediction distribution:")
print(df_valid['prediction'].value_counts())
print("\nGround truth distribution (survival):")
print(df_valid['ground_truth_survival'].value_counts())

## 3. Confusion Matrix

In [None]:
y_true = df_valid['ground_truth_survival'].values
y_pred = df_valid['prediction'].values

# Rows = ground truth, columns = prediction, both ordered [Death (0), Survival (1)].
# Pinning labels keeps the matrix 2x2 even if a class never occurs in y_true or
# y_pred; otherwise cm[0, 1] etc. and the two display labels below would fail.
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

# Plot confusion matrix
fig, ax = plt.subplots(figsize=(8, 6))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Death (0)', 'Survival (1)'])
disp.plot(cmap='Blues', ax=ax, values_format='d')
ax.set_title('Confusion Matrix: LLM Survival Predictions (Improved)', fontsize=14)
fig.tight_layout()
fig.savefig('confusion_matrix_improved.png', dpi=150, bbox_inches='tight')
plt.show()

# Print raw values (survival is the positive class)
print("\nConfusion Matrix:")
print(f"  TN (True Death):     {cm[0, 0]}")
print(f"  FP (False Survival): {cm[0, 1]}")
print(f"  FN (False Death):    {cm[1, 0]}")
print(f"  TP (True Survival):  {cm[1, 1]}")

## 4. Classification Metrics

In [None]:
# Calculate metrics. Survival (1) is the positive class (sklearn's default pos_label),
# so precision/recall/F1 describe how well the model predicts survival.
accuracy = accuracy_score(y_true, y_pred)
precision = precision_score(y_true, y_pred, zero_division=0)
recall = recall_score(y_true, y_pred, zero_division=0)
f1 = f1_score(y_true, y_pred, zero_division=0)

# Specificity (true negative rate = recall on the death class). Counts are
# recomputed with pinned labels so the 4-way unpack works even when a class is
# absent, independent of how `cm` was built in an earlier cell.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

# Print metrics
print("=" * 50)
print("CLASSIFICATION METRICS (Improved Extraction)")
print("=" * 50)
print(f"Accuracy:    {accuracy:.4f} ({accuracy*100:.2f}%)")
print(f"Precision:   {precision:.4f}")
print(f"Recall:      {recall:.4f} (Sensitivity)")
print(f"Specificity: {specificity:.4f}")
print(f"F1-Score:    {f1:.4f}")
print("=" * 50)

In [None]:
# Full classification report
print("\nDetailed Classification Report:")
print(classification_report(y_true, y_pred, target_names=['Death (0)', 'Survival (1)']))

## 5. Summary

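This notebook evaluates the LLM's survival predictions using an improved prediction extraction with:
- Negation handling
- Conflict detection
- End-of-text priority for tie-breaking

Note: responses with no prediction or an unresolved conflict are excluded before the metrics are computed. All metrics therefore refer to the valid-prediction subset; check the valid-prediction rate printed in Section 2 when interpreting them.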