<a href="https://colab.research.google.com/github/DarthCoder501/GAAP/blob/main/Baseline_Impressions_Model_w_confidence_scores.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install tensorflow transformers scikit-learn



In [2]:
import pandas as pd
import numpy as np
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModel
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Input, Dropout
from sklearn.metrics import roc_auc_score, classification_report, multilabel_confusion_matrix
import warnings
warnings.filterwarnings('ignore')

In [3]:
# Load data
train = pd.read_csv("/content/Train Chronological (1).csv", parse_dates=["note_DATETIME"])
test = pd.read_csv("/content/Impressions Chronological Data (1).csv", parse_dates=["note_DATETIME"])

In [4]:
# Abbreviations dictionary
abbreviations = {
    "MM": "millimeter",
    "CT": "computed tomography",
    "HCW": "healthcare Worker",
    "BAC": "bronchioloalveolar carcinoma",
    "ED": "emergency department",
    "PT": "pacific time",
    "MRI": "magnetic resonance imaging",
    "RN": "registered nurse",
    "PACS": "picture archiving and communication system",
    "PET": "positron emission tomography",
    "SVC": "superior vena cava",
    "CM": "centimeter",
    "RLL": "right lower lobe",
    "RUL": "right upper lobe",
    "LAD": "left anterior descending artery",
    "TB": "tuberculosis",
    "IPMT": "intraductal papillary mucinous tumor",
    "IVC": "inferior vena cava",
    "PE": "pulmonary embolism",
    "PEs": "pulmonary embolisms",
    "FDG": "fluorodeoxyglucose",
    "SFV": "superficial femoral vein",
    "DVT": "deep vein thrombosis",
    "SMA": "superior mesenteric artery",
    "NSIP": "nonspecific interstitial pneumonia",
    "SITU": "in its original place",
    "HR": "hour",
    "4A": "the superior part of the left medial segment of the liver",
    "PST": "pacific standard time",
    "ID": "identification",
    "CTA": "computed tomography angiography",
    "NG": "nasogastric",
    "IPMN": "intraductal papillary mucinous neoplasm",
    "UIP": "usual interstitial pneumonia",
    "ER": "emergency room",
    "ARDS": "acute respiratory distress syndrome",
    "MRN": "medical record number",
    "RV": "right ventricular",
    "CHF": "congestive heart failure",
    "PEG": "percutaneous endoscopic gastrostomy",
    "PICC": "peripherally inserted central catheter",
    "GI": "gastrointestinal",
    "ASD": "atrial septal defect",
    "MR": "mitral regurgitation",
    "EST": "eastern standard time",
    "CTs": "computed tomographies",
    "3D": "three dimensional",
    "MAC": "mycobacterium avium complex",
    "MICU": "medical intensive care unit",
    "MAI": "mycobacterium avium-intracellulare",
    "PJP": "pneumocystis jirovecii pneumonia",
    "LIMA": "left internal mammary artery",
    "LV": "left ventricle",
    "EGD": "esophagogastroduodenoscopy",
    "PAU": "penetrating atherosclerotic ulcer",
    "VP": "ventriculoperitoneal",
    "CSF": "cerebrospinal fluid",
    "HCC": "hepatocellular carcinoma",
    "SABR": "stereotactic ablative radiotherapy",
    "ILD": "interstitial lung disease",
    "IVP": "intravenous pyelogram",
    "MRCP": "magnetic resonance cholangiopancreatography",
    "IV": "intravenous",
    "RCA": "right coronary artery",
    "COVID": "coronavirus disease",
    "2D": "two dimensional",
    "SMV": "superior mesenteric vein",
    "FNA": "fine needle aspiration",
    "BAL": "bronchoalveolar Lavage",
    "AVMs": "arteriovenous malformations",
    "AVM": "arteriovenous malformation",
    "MRA": "magnetic resonance angiography",
    "AP": "anteroposterior",
    "MRIs": "magnetic resonance imaging",
    "COVID19": "coronavirus disease 2019",
    "BHD": "birt-hogg-dube",
    "CTEPH": "chronic thromboembolic pulmonary hypertension",
    "RML": "right middle lobe",
    "NGT": "nasogastric tube",
    "GE": "gastroesophageal",
    "MDS": "myelodysplastic syndrome",
    "UVJ": "ureterovesical junction",
    "ERCP": "endoscopic retrograde cholangiopancreatography",
    "OP": "organizing pneumonia",
    "IJ": "internal jugular",
    "VSD": "ventricular septal defect",
    "EMR": "electronic medical record",
    "TE": "tracheoesophageal",
    "AV": "arteriovenous",
    "PAN": "polyarteritis nodosa",
    "III": "third",
    "SLE": "systemic lupus erythematosus",
    "CTS": "computed tomographies",
    "IPF": "idiopathic pulmonary fibrosis",
    "3MM": "three millimeters",
    "4MM": "four millimeters",
    "PAPVR": "partial anomalous pulmonary venous return",
    "ANCA": "antineutrophil cytoplasmic antibodies",
    "VQ": "ventilation-perfusion",
    "PA": "pulmonary artery",
    "PCP": "pneumocystis pneumonia",
    "CMV": "cytomegalovirus",
    "RVH": "right ventricular hypertrophy",
    "TSH": "thyroid stimulating hormone",
    "CBD": "common bile duct",
    "BNP": "brain natriuretic peptide",
    "16MM": "sixteen millimeters",
    "NP": "nurse practitioner",
    "CVC": "central venous catheter",
    "SVG": "saphenous vein graft",
    "PDA": "posterior descending artery",
    "VIII": "eighth",
    "ICU": "intensive care unit",
    "CPR": "cardiopulmonary resuscitation",
    "DAH": "diffuse alveolar hemorrhage",
    "PAP": "pulmonary alveolar proteinosis",
    "II": "second",
    "ENT": "ear, nose, and throat",
    "FNH": "focal nodular hyperplasia",
    "LLL": "left lower lobe",
    "CTPA": "computed tomography pulmonary angiography",
    "LA": "left atrium",
    "ABPA": "allergic bronchopulmonary aspergillosis",
    "IMA": "inferior mesenteric artery",
    "RT": "right",
    "CCU": "coronary care unit",
    "ALS": "amyotrophic lateral sclerosis",
    "LT": "left",
    "RCC": "renal cell carcinoma",
    "AML": "angiomyolipoma",
    "HCG": "human chorionic gonadotropin",
    "IJV": "internal jugular vein",
    "LE": "lower extremity",
    "ASAP": "as soon as possible",
    "1L": "one liter",
    "IHSS": "Idiopathic hypertrophic subaortic stenosis",
    "13MM": "thirteen millimeters",
    "PFO": "patent foramen ovale",
    "CCA": "common carotid artery",
    "SCA": "subclavian artery",
    "ANS": "anteromedial basal subsegmental artery",
    "IgG4": "Immunoglobulin G4",
    "ICD": "implantable cardioverter-defibrillator",
    "T9": "ninth thoracic vertebrae",
    "CVICU": "cardiovascular intensive care unit",
    "T12": "twelfth thoracic vertebra",
    "L5": "fifth lumbar vertebra",
    "L1": "first lumbar vertebra",
    "L3": "third lumbar vertebra",
    "T4": "fourth thoracic vertebrae",
    "T5": "fifth thoracic vertebra",
    "T7": "seventh thoracic vertebra",
    "T8": "eighth thoracic vertebra",
    "T10": "tenth thoracic vertebra",
    "L2": "second lumbar vertebra",
    "8MM": "eight millimeters",
    "T2": "second thoracic vertebra",
    "IUD": "intrauterine device",
    "T3": "third thoracic vertebrae",
    "T6": "sixth thoracic vertebrae",
    "C7": "seventh cervical vertebra",
    "S4": "fourth heart sound",
    "T11": "eleventh thoracic vertebra",
    "L4": "fourth lumbar vertebra",
    "T1": "first thoracic vertebra",
    "S1": "first heart sound",
    "PAH": "pulmonary arterial hypertension",
    "S9": "ninth heart sound",
    "IMH": "intramural hematoma",
    "VATS": "video-assisted thoracoscopic surgery",
    "S2": "second heart sound",
    "LVAD": "left ventricular assist device",
}

In [5]:
# Data preprocessing
print("Preprocessing text data...")
train["impressions_clean"] = train["impressions"].replace(abbreviations, regex=True)
test["impressions_clean"] = test["impressions"].replace(abbreviations, regex=True)

Preprocessing text data...


In [6]:
# Clean text
train["impressions_clean"] = train["impressions_clean"].str.lower().str.replace(r'[^a-z0-9\s]', '', regex=True)
test["impressions_clean"] = test["impressions_clean"].str.lower().str.replace(r'[^a-z0-9\s]', '', regex=True)

In [7]:
# Handle missing values
train = train.dropna(subset=["impressions_clean"])
test = test.dropna(subset=["impressions_clean"])

In [8]:
# Remove empty strings after cleaning
train = train[train["impressions_clean"].str.strip() != ""]
test = test[test["impressions_clean"].str.strip() != ""]

print(f"Train samples after cleaning: {len(train)}")
print(f"Test samples after cleaning: {len(test)}")

Train samples after cleaning: 4460
Test samples after cleaning: 1116


In [9]:
# Load tokenizer and model
print("Loading Bio_ClinicalBERT...")
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
bert_model = TFAutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

Loading Bio_ClinicalBERT...


config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tf_model.h5:   0%|          | 0.00/527M [00:00<?, ?B/s]

TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We recommend migrating to PyTorch classes or pinning your version of Transformers.
Some layers from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT were not used when initializing TFBertModel: ['mlm___cls', 'nsp___cls']
- This IS expected if you are initializing TFBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFBertModel were initialized from the model checkpoint at emilyalsentzer/Bio_ClinicalBERT.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predicti

In [10]:
# Tokenize text
max_seq_length = 128
print("Tokenizing text...")
train_encodings = tokenizer(
    train["impressions_clean"].tolist(),
    padding=True,
    truncation=True,
    max_length=max_seq_length,
    return_tensors="tf"
)
test_encodings = tokenizer(
    test["impressions_clean"].tolist(),
    padding=True,
    truncation=True,
    max_length=max_seq_length,
    return_tensors="tf"
)

TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We recommend migrating to PyTorch classes or pinning your version of Transformers.


Tokenizing text...


In [11]:
# Extract embeddings function (fixed)
def extract_embeddings_in_batches(model, encodings, batch_size=32):
    """Extract CLS token embeddings in batches to manage memory"""
    all_embeddings = []
    num_samples = encodings.input_ids.shape[0]

    for i in range(0, num_samples, batch_size):
        end_idx = min(i + batch_size, num_samples)

        batch_input_ids = encodings.input_ids[i:end_idx]
        batch_attention_mask = encodings.attention_mask[i:end_idx]

        # Get embeddings for this batch
        outputs = model(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
        # Use CLS token (first token) embeddings
        batch_embeddings = outputs.last_hidden_state[:, 0, :]
        all_embeddings.append(batch_embeddings.numpy())

        if (i // batch_size + 1) % 10 == 0:
            print(f"Processed {i + batch_size} / {num_samples} samples")

    return np.concatenate(all_embeddings, axis=0)

In [12]:
# Extract embeddings
print("Extracting embeddings...")
batch_size = 32  # Reduced batch size for stability
X_train = extract_embeddings_in_batches(bert_model, train_encodings, batch_size)
X_test = extract_embeddings_in_batches(bert_model, test_encodings, batch_size)

Extracting embeddings...
Processed 320 / 4460 samples
Processed 640 / 4460 samples
Processed 960 / 4460 samples
Processed 1280 / 4460 samples
Processed 1600 / 4460 samples
Processed 1920 / 4460 samples
Processed 2240 / 4460 samples
Processed 2560 / 4460 samples
Processed 2880 / 4460 samples
Processed 3200 / 4460 samples
Processed 3520 / 4460 samples
Processed 3840 / 4460 samples
Processed 4160 / 4460 samples
Processed 4480 / 4460 samples
Processed 320 / 1116 samples
Processed 640 / 1116 samples
Processed 960 / 1116 samples


In [13]:
print(f"Shape of X_train: {X_train.shape}")
print(f"Shape of X_test: {X_test.shape}")

Shape of X_train: (4460, 768)
Shape of X_test: (1116, 768)


In [14]:
# Define targets
targets = ["1_month_readmission", "6_month_readmission", "12_month_readmission", "pe_positive"]

# Check if target columns exist and handle missing values
print("Preparing target variables...")
for target in targets:
    if target not in train.columns:
        print(f"Warning: {target} not found in train data")
    if target not in test.columns:
        print(f"Warning: {target} not found in test data")

# Create target arrays & handle missing columns
available_targets = [t for t in targets if t in train.columns and t in test.columns]
print(f"Available targets: {available_targets}")

if not available_targets:
    print("No target columns found! Please check column names.")
    # Print available columns for debugging
    print("Train columns:", train.columns.tolist())
    print("Test columns:", test.columns.tolist())
else:
    y_train = train[available_targets].fillna(0).astype(int)
    y_test = test[available_targets].fillna(0).astype(int)

    print("Target distribution in training data:")
    for target in available_targets:
        pos_count = y_train[target].sum()
        total_count = len(y_train)
        print(f"{target}: {pos_count}/{total_count} ({pos_count/total_count*100:.1f}% positive)")

Preparing target variables...
Available targets: ['1_month_readmission', '6_month_readmission', '12_month_readmission', 'pe_positive']
Target distribution in training data:
1_month_readmission: 201/4460 (4.5% positive)
6_month_readmission: 593/4460 (13.3% positive)
12_month_readmission: 813/4460 (18.2% positive)
pe_positive: 1096/4460 (24.6% positive)


In [15]:
# Build the multi-label classification model
print("Building model...")
input_layer = Input(shape=(X_train.shape[1],), name='embeddings_input')
dense1 = Dense(256, activation="relu", name='dense1')(input_layer)
dropout1 = Dropout(0.3, name='dropout1')(dense1)
dense2 = Dense(128, activation="relu", name='dense2')(dropout1)
dropout2 = Dropout(0.2, name='dropout2')(dense2)

# Create task-specific output heads
outputs = []
for target in available_targets:
    output = Dense(1, activation="sigmoid", name=target)(dropout2)
    outputs.append(output)

# Create and compile model
classification_model = Model(inputs=input_layer, outputs=outputs)

# For multi-output models, we need to provide metrics for each output
if len(available_targets) == 1:
    # Single output - simple metrics list
    classification_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )
else:
    # Multiple outputs - provide metrics list matching number of outputs
    # Each entry in the list corresponds to one output
    metrics_list = [['accuracy'] for _ in range(len(available_targets))]
    classification_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='binary_crossentropy',
        metrics=metrics_list
    )

Building model...


In [16]:
print("Model architecture:")
classification_model.summary()

Model architecture:


In [17]:
# Training
print("Training model...")
history = classification_model.fit(
    X_train,
    [y_train[target].values for target in available_targets],
    validation_data=(X_test, [y_test[target].values for target in available_targets]),
    epochs=10,
    batch_size=32,
    verbose=1
)

Training model...
Epoch 1/10
[1m140/140[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 30ms/step - 12_month_readmission_accuracy: 0.8052 - 12_month_readmission_loss: 0.5073 - 1_month_readmission_accuracy: 0.9416 - 1_month_readmission_loss: 0.2348 - 6_month_readmission_accuracy: 0.8268 - 6_month_readmission_loss: 0.4479 - loss: 1.6943 - pe_positive_accuracy: 0.7493 - pe_positive_loss: 0.5043 - val_12_month_readmission_accuracy: 0.7679 - val_12_month_readmission_loss: 0.5486 - val_1_month_readmission_accuracy: 0.9194 - val_1_month_readmission_loss: 0.2804 - val_6_month_readmission_accuracy: 0.8172 - val_6_month_readmission_loss: 0.4838 - val_loss: 1.6342 - val_pe_positive_accuracy: 0.8593 - val_pe_positive_loss: 0.3238
Epoch 2/10
[1m140/140[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 5ms/step - 12_month_readmission_accuracy: 0.8029 - 12_month_readmission_loss: 0.5056 - 1_month_readmission_accuracy: 0.9534 - 1_month_readmission_loss: 0.1913 - 6_month_readmission_accura

In [18]:
# ENHANCED EVALUATION WITH CONFIDENCE SCORES

print("\nEvaluating model and extracting confidence scores...")

# Get predictions (these are the confidence scores!)
predictions = classification_model.predict(X_test, batch_size=32)

# Ensure predictions is a list for multi-output models
if len(available_targets) == 1:
    predictions = [predictions]

# Create a comprehensive results dataframe
results_df = test.copy()

# Add confidence scores and binary predictions for each target
for i, target in enumerate(available_targets):
    confidence_scores = predictions[i].ravel()
    binary_predictions = (confidence_scores > 0.5).astype(int)

    # Add to results dataframe
    results_df[f'{target}_confidence'] = confidence_scores
    results_df[f'{target}_prediction'] = binary_predictions
    results_df[f'{target}_true'] = y_test[target].values

# Save results with confidence scores
results_df.to_csv('test_results_with_confidence.csv', index=False)
print("Results with confidence scores saved to 'test_results_with_confidence.csv'")


Evaluating model and extracting confidence scores...
[1m35/35[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step
Results with confidence scores saved to 'test_results_with_confidence.csv'


In [19]:
# CONFIDENCE SCORE ANALYSIS

print("\n" + "="*60)
print("CONFIDENCE SCORE ANALYSIS")
print("="*60)

# Calculate metrics for each target
results = {}
for i, target in enumerate(available_targets):
    y_true = y_test[target].values
    confidence_scores = predictions[i].ravel()
    y_pred_binary = (confidence_scores > 0.5).astype(int)

    # Calculate AUC-ROC (handle edge case where all labels are the same class)
    try:
        if len(np.unique(y_true)) > 1:
            auc_score = roc_auc_score(y_true, confidence_scores)
        else:
            auc_score = np.nan
            print(f"Warning: {target} has only one class in test set, AUC cannot be calculated")
    except Exception as e:
        auc_score = np.nan
        print(f"Error calculating AUC for {target}: {e}")

    results[target] = {
        'auc_roc': auc_score,
        'accuracy': np.mean(y_true == y_pred_binary),
        'precision': np.sum((y_pred_binary == 1) & (y_true == 1)) / max(np.sum(y_pred_binary == 1), 1),
        'recall': np.sum((y_pred_binary == 1) & (y_true == 1)) / max(np.sum(y_true == 1), 1),
        'confidence_stats': {
            'mean': np.mean(confidence_scores),
            'std': np.std(confidence_scores),
            'min': np.min(confidence_scores),
            'max': np.max(confidence_scores)
        }
    }

    print(f"\n{target} Results:")
    print(f"  AUC-ROC: {auc_score:.4f}" if not np.isnan(auc_score) else "  AUC-ROC: N/A")
    print(f"  Accuracy: {results[target]['accuracy']:.4f}")
    print(f"  Precision: {results[target]['precision']:.4f}")
    print(f"  Recall: {results[target]['recall']:.4f}")

    # Confidence score statistics
    print(f"  Confidence Scores:")
    print(f"    Mean: {results[target]['confidence_stats']['mean']:.4f}")
    print(f"    Std:  {results[target]['confidence_stats']['std']:.4f}")
    print(f"    Min:  {results[target]['confidence_stats']['min']:.4f}")
    print(f"    Max:  {results[target]['confidence_stats']['max']:.4f}")

    # Classification report
    print(f"\nClassification Report for {target}:")
    print(classification_report(y_true, y_pred_binary, zero_division=0))


CONFIDENCE SCORE ANALYSIS

1_month_readmission Results:
  AUC-ROC: 0.5778
  Accuracy: 0.9194
  Precision: 0.0000
  Recall: 0.0000
  Confidence Scores:
    Mean: 0.0399
    Std:  0.0236
    Min:  0.0010
    Max:  0.1704

Classification Report for 1_month_readmission:
              precision    recall  f1-score   support

           0       0.92      1.00      0.96      1026
           1       0.00      0.00      0.00        90

    accuracy                           0.92      1116
   macro avg       0.46      0.50      0.48      1116
weighted avg       0.85      0.92      0.88      1116


6_month_readmission Results:
  AUC-ROC: 0.5257
  Accuracy: 0.8172
  Precision: 0.0000
  Recall: 0.0000
  Confidence Scores:
    Mean: 0.1388
    Std:  0.0527
    Min:  0.0218
    Max:  0.2902

Classification Report for 6_month_readmission:
              precision    recall  f1-score   support

           0       0.82      1.00      0.90       912
           1       0.00      0.00      0.00       204



In [20]:
# CONFIDENCE THRESHOLD ANALYSIS

print("\n" + "="*60)
print("CONFIDENCE THRESHOLD ANALYSIS")
print("="*60)

# Analyze performance at different confidence thresholds
thresholds = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]

for target_idx, target in enumerate(available_targets):
    print(f"\n{target} - Performance at different confidence thresholds:")
    print("Threshold | Accuracy | Precision | Recall | F1-Score")
    print("-" * 55)

    y_true = y_test[target].values
    confidence_scores = predictions[target_idx].ravel()

    for threshold in thresholds:
        y_pred_thresh = (confidence_scores > threshold).astype(int)

        accuracy = np.mean(y_true == y_pred_thresh)
        precision = np.sum((y_pred_thresh == 1) & (y_true == 1)) / max(np.sum(y_pred_thresh == 1), 1)
        recall = np.sum((y_pred_thresh == 1) & (y_true == 1)) / max(np.sum(y_true == 1), 1)
        f1 = 2 * (precision * recall) / max(precision + recall, 1e-8)

        print(f"   {threshold:.1f}    |   {accuracy:.3f}  |   {precision:.3f}   |  {recall:.3f}  |  {f1:.3f}")


CONFIDENCE THRESHOLD ANALYSIS

1_month_readmission - Performance at different confidence thresholds:
Threshold | Accuracy | Precision | Recall | F1-Score
-------------------------------------------------------
   0.3    |   0.919  |   0.000   |  0.000  |  0.000
   0.4    |   0.919  |   0.000   |  0.000  |  0.000
   0.5    |   0.919  |   0.000   |  0.000  |  0.000
   0.6    |   0.919  |   0.000   |  0.000  |  0.000
   0.7    |   0.919  |   0.000   |  0.000  |  0.000
   0.8    |   0.919  |   0.000   |  0.000  |  0.000
   0.9    |   0.919  |   0.000   |  0.000  |  0.000

6_month_readmission - Performance at different confidence thresholds:
Threshold | Accuracy | Precision | Recall | F1-Score
-------------------------------------------------------
   0.3    |   0.817  |   0.000   |  0.000  |  0.000
   0.4    |   0.817  |   0.000   |  0.000  |  0.000
   0.5    |   0.817  |   0.000   |  0.000  |  0.000
   0.6    |   0.817  |   0.000   |  0.000  |  0.000
   0.7    |   0.817  |   0.000   |  0

In [21]:
# HIGH/LOW CONFIDENCE SAMPLE ANALYSIS

print("\n" + "="*60)
print("HIGH/LOW CONFIDENCE SAMPLE ANALYSIS")
print("="*60)

for target_idx, target in enumerate(available_targets):
    confidence_scores = predictions[target_idx].ravel()
    y_true = y_test[target].values

    # Find high confidence correct predictions
    high_conf_correct = np.where((confidence_scores > 0.8) &
                                ((confidence_scores > 0.5) == y_true))[0]

    # Find low confidence predictions
    low_conf_predictions = np.where((confidence_scores > 0.3) &
                                   (confidence_scores < 0.7))[0]

    print(f"\n{target}:")
    print(f"  High confidence correct predictions: {len(high_conf_correct)}")
    print(f"  Low confidence predictions (0.3-0.7): {len(low_conf_predictions)}")

    if len(high_conf_correct) > 0:
        print(f"  Sample high confidence scores: {confidence_scores[high_conf_correct[:5]]}")

    if len(low_conf_predictions) > 0:
        print(f"  Sample low confidence scores: {confidence_scores[low_conf_predictions[:5]]}")


HIGH/LOW CONFIDENCE SAMPLE ANALYSIS

1_month_readmission:
  High confidence correct predictions: 0
  Low confidence predictions (0.3-0.7): 0

6_month_readmission:
  High confidence correct predictions: 0
  Low confidence predictions (0.3-0.7): 0

12_month_readmission:
  High confidence correct predictions: 0
  Low confidence predictions (0.3-0.7): 6
  Sample low confidence scores: [0.31311738 0.30439335 0.34502783 0.3295886  0.31596285]

pe_positive:
  High confidence correct predictions: 90
  Low confidence predictions (0.3-0.7): 105
  Sample high confidence scores: [0.96669936 0.85562795 0.84882694 0.92313653 0.9216634 ]
  Sample low confidence scores: [0.5562965  0.36007357 0.40486163 0.5080258  0.4461701 ]


In [22]:
# FUNCTION TO GET CONFIDENCE FOR NEW DATA

def get_confidence_predictions(model, tokenizer, texts, max_length=128, batch_size=32):
    """
    Get confidence predictions for new text data

    Args:
        model: Trained classification model
        tokenizer: BERT tokenizer
        texts: List of text strings
        max_length: Maximum sequence length
        batch_size: Batch size for processing

    Returns:
        Dictionary with confidence scores for each target
    """
    # Tokenize
    encodings = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="tf"
    )

    # Extract embeddings
    embeddings = extract_embeddings_in_batches(bert_model, encodings, batch_size)

    # Get predictions
    predictions = model.predict(embeddings, batch_size=batch_size)

    # Format results
    if len(available_targets) == 1:
        predictions = [predictions]

    results = {}
    for i, target in enumerate(available_targets):
        results[target] = {
            'confidence_scores': predictions[i].ravel(),
            'binary_predictions': (predictions[i].ravel() > 0.5).astype(int)
        }

    return results

# Example usage of the confidence function
print("\n" + "="*60)
print("EXAMPLE: Getting confidence for new texts")
print("="*60)

# Example texts (you can replace with your own)
example_texts = [
    "patient shows signs of improvement",
    "concerning findings require immediate attention",
    "normal examination results"
]

if len(example_texts) > 0:
    example_results = get_confidence_predictions(
        classification_model,
        tokenizer,
        example_texts
    )

    for i, text in enumerate(example_texts):
        print(f"\nText {i+1}: '{text}'")
        for target in available_targets:
            conf = example_results[target]['confidence_scores'][i]
            pred = example_results[target]['binary_predictions'][i]
            print(f"  {target}: {conf:.4f} (prediction: {pred})")

# Overall model performance
print("\n" + "="*50)
print("OVERALL MODEL PERFORMANCE SUMMARY")
print("="*50)

avg_auc = np.nanmean([results[target]['auc_roc'] for target in available_targets])
avg_accuracy = np.mean([results[target]['accuracy'] for target in available_targets])
avg_precision = np.mean([results[target]['precision'] for target in available_targets])
avg_recall = np.mean([results[target]['recall'] for target in available_targets])

print(f"Average AUC-ROC: {avg_auc:.4f}")
print(f"Average Accuracy: {avg_accuracy:.4f}")
print(f"Average Precision: {avg_precision:.4f}")
print(f"Average Recall: {avg_recall:.4f}")

print("\nModel training completed successfully!")
print("Confidence scores have been extracted and saved!")


EXAMPLE: Getting confidence for new texts
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 254ms/step

Text 1: 'patient shows signs of improvement'
  1_month_readmission: 0.0195 (prediction: 0)
  6_month_readmission: 0.1059 (prediction: 0)
  12_month_readmission: 0.1356 (prediction: 0)
  pe_positive: 0.0456 (prediction: 0)

Text 2: 'concerning findings require immediate attention'
  1_month_readmission: 0.0128 (prediction: 0)
  6_month_readmission: 0.0715 (prediction: 0)
  12_month_readmission: 0.1070 (prediction: 0)
  pe_positive: 0.0027 (prediction: 0)

Text 3: 'normal examination results'
  1_month_readmission: 0.0083 (prediction: 0)
  6_month_readmission: 0.0518 (prediction: 0)
  12_month_readmission: 0.0801 (prediction: 0)
  pe_positive: 0.0006 (prediction: 0)

OVERALL MODEL PERFORMANCE SUMMARY
Average AUC-ROC: 0.6481
Average Accuracy: 0.8533
Average Precision: 0.2207
Average Recall: 0.1582

Model training completed successfully!
Confidence scores have been extracted