### Model Deployment Using the PubMedBERT Model

In [18]:
# Mount Google Drive at /content/drive. `google.colab` is only importable inside
# a Colab runtime, so this cell fails on other Jupyter frontends.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [23]:
#!/usr/bin/env python
# coding: utf-8

### Enhanced Medical Trial Condition Predictor by Harry Patria

# Import necessary libraries
#!pip install ipywidgets
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
import numpy as np
import torch
from torch.utils.data import Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import time
import json
import os
import re
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from io import BytesIO
import base64
import warnings
warnings.filterwarnings('ignore')

from collections import Counter
from transformers import AutoTokenizer

# Download a pre-trained tokenizer (e.g., bert-base-uncased)
# NOTE(review): the module-level name `tokenizer` is rebound later in this cell by
# the tuple returned from load_transformer_model(MODEL_DIR), so this general-purpose
# tokenizer is only in effect until that assignment runs.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Set seed for reproducibility
# Seeds both the torch and the NumPy global generators; the demo-mode prediction
# path draws from np.random.
torch.manual_seed(42)
np.random.seed(42)

# Set styling for matplotlib visualizations with transparent background
# These are process-wide defaults: every figure created after this point inherits
# them. 'none' face colours keep the figure background transparent; '#21212180'
# is an RGBA hex value (dark grey with an alpha byte of 0x80).
plt.rcParams.update({
    'figure.facecolor': 'none',
    'axes.facecolor': '#21212180',
    'savefig.facecolor': 'none',
    'text.color': 'white',
    'axes.labelcolor': 'white',
    'xtick.color': 'white',
    'ytick.color': 'white',
    'grid.color': '#505050',
    'figure.autolayout': True,
    'font.family': 'sans-serif',
    'font.sans-serif': ['Arial', 'Helvetica', 'DejaVu Sans'],
})

# Dark-theme stylesheet for the predictor UI. The whole <style> element is a runtime
# string that is emitted once through display(HTML(...)) at the end of this block.
# It defines the colour variables, the ipywidgets overrides, and the component classes
# (custom-header, results-panel, error-message, loading-spinner, tooltip, key-term, ...)
# that the HTML fragments in this notebook reference by class name.
# NOTE(review): selectors such as `p`, `h2, h3, h4` and `:focus` are not namespaced, so
# depending on how the frontend isolates outputs they may restyle content outside this UI.
custom_style = """
<style>
    /* Main theme variables */
    :root {
        --primary-color: #bb86fc;
        --primary-variant: #3700b3;
        --secondary-color: #03dac6;
        --error-color: #cf6679;
        --background: #121212;
        --surface: #1e1e1e;
        --on-primary: #000000;
        --on-secondary: #000000;
        --on-background: #ffffff;
        --on-surface: #ffffff;
        --on-error: #000000;
        --on-background-medium: rgba(255, 255, 255, 0.7);
        --on-background-disabled: rgba(255, 255, 255, 0.38);
        --surface-1: rgba(255, 255, 255, 0.05);
        --surface-2: rgba(255, 255, 255, 0.08);
        --surface-3: rgba(255, 255, 255, 0.11);
    }

    /* Widget styling */
    .widget-label {
        color: var(--primary-color) !important;
        font-weight: bold !important;
        font-size: 1.05em !important;
    }

    .jupyter-widgets {
        font-family: 'Segoe UI', system-ui, -apple-system, Arial, sans-serif !important;
    }

    .widget-textarea > textarea {
        background-color: var(--surface) !important;
        color: var(--on-background) !important;
        border: 1px solid #424242 !important;
        border-radius: 8px !important;
        padding: 12px !important;
        font-size: 16px !important;
        transition: border-color 0.2s ease-in-out;
        line-height: 1.5 !important;
    }

    .widget-textarea > textarea:focus {
        border-color: var(--primary-color) !important;
        box-shadow: 0 0 0 2px rgba(187, 134, 252, 0.25) !important;
        outline: none !important;
    }

    .jupyter-button {
        text-transform: uppercase !important;
        font-weight: 600 !important;
        letter-spacing: 0.5px !important;
        border-radius: 8px !important;
        padding: 8px 16px !important;
        transition: all 0.2s ease !important;
        border: none !important;
        box-shadow: 0 3px 5px rgba(0,0,0,0.2) !important;
    }

    .jupyter-button:hover {
        transform: translateY(-2px) !important;
        box-shadow: 0 5px 8px rgba(0,0,0,0.3) !important;
    }

    .jupyter-button:active {
        transform: translateY(0) !important;
        box-shadow: 0 2px 3px rgba(0,0,0,0.3) !important;
    }

    .jupyter-button.mod-primary {
        background-color: var(--primary-color) !important;
        color: var(--on-primary) !important;
    }

    .jupyter-button.mod-warning {
        background-color: #424242 !important;
        color: var(--on-background) !important;
    }

    /* Typography */
    h2, h3, h4 {
        color: var(--primary-color) !important;
        font-weight: 700 !important;
        line-height: 1.3 !important;
    }

    p {
        color: var(--on-background) !important;
        line-height: 1.6 !important;
        margin-bottom: 16px !important;
    }

    /* Components */
    .custom-header {
        background-color: var(--surface);
        padding: 24px;
        margin-bottom: 24px;
        border-radius: 12px;
        border-left: 5px solid var(--primary-color);
        box-shadow: 0 4px 20px rgba(0, 0, 0, 0.3);
    }

    .custom-footer {
        background-color: var(--surface);
        padding: 16px;
        margin-top: 32px;
        border-radius: 12px;
        text-align: center;
        font-size: 0.9em;
        color: var(--on-background-medium) !important;
        box-shadow: 0 -2px 10px rgba(0, 0, 0, 0.1);
    }

    .results-panel {
        background-color: var(--surface);
        padding: 20px;
        border-radius: 12px;
        margin-top: 24px;
        box-shadow: 0 6px 18px rgba(0, 0, 0, 0.25);
        animation: fadeIn 0.5s ease-in-out;
    }

    .error-message {
        color: var(--error-color) !important;
        background-color: rgba(207, 102, 121, 0.1);
        padding: 12px 16px;
        border-radius: 8px;
        border-left: 4px solid var(--error-color);
        margin: 16px 0;
        font-weight: 500;
        animation: slideIn 0.3s ease-in-out;
    }

    /* Loading spinner with accessibility considerations */
    .loading-container {
        display: flex;
        flex-direction: column;
        align-items: center;
        justify-content: center;
        padding: 24px;
        margin: 16px 0;
    }

    .loading-spinner {
        border: 5px solid #424242;
        border-radius: 50%;
        border-top: 5px solid var(--primary-color);
        width: 48px;
        height: 48px;
        animation: spin 1s linear infinite;
        margin-bottom: 16px;
    }

    .sr-only {
        position: absolute;
        width: 1px;
        height: 1px;
        padding: 0;
        margin: -1px;
        overflow: hidden;
        clip: rect(0, 0, 0, 0);
        white-space: nowrap;
        border-width: 0;
    }

    .model-info {
        background-color: rgba(187, 134, 252, 0.1);
        padding: 16px;
        border-radius: 8px;
        border-left: 4px solid var(--primary-color);
        margin: 16px 0;
        font-size: 0.95em;
    }

    /* Key terms highlighting */
    .key-term {
        background-color: rgba(187, 134, 252, 0.2);
        padding: 0 5px;
        border-radius: 4px;
        font-weight: 500;
    }

    /* Tooltips for accessibility */
    .tooltip {
        position: relative;
        display: inline-block;
        border-bottom: 1px dotted var(--primary-color);
        cursor: help;
    }

    .tooltip .tooltiptext {
        visibility: hidden;
        background-color: #333;
        color: #fff;
        text-align: center;
        border-radius: 6px;
        padding: 8px 12px;
        position: absolute;
        z-index: 1;
        bottom: 125%;
        left: 50%;
        transform: translateX(-50%);
        width: max-content;
        max-width: 250px;
        opacity: 0;
        transition: opacity 0.3s;
        font-size: 14px;
        line-height: 1.4;
        box-shadow: 0 4px 8px rgba(0,0,0,0.4);
    }

    .tooltip:hover .tooltiptext {
        visibility: visible;
        opacity: 0.95;
    }

    /* Prediction badges */
    .prediction-badge {
        display: inline-block;
        background-color: var(--primary-variant);
        color: white;
        padding: 6px 12px;
        border-radius: 16px;
        font-weight: 600;
        margin-left: 8px;
        font-size: 1.1em;
    }

    /* Data visualization container */
    .viz-container {
        background-color: rgba(30, 30, 30, 0.7);
        padding: 16px;
        border-radius: 12px;
        margin: 20px 0;
        border: 1px solid #424242;
    }

    /* Animation keyframes */
    @keyframes spin {
        0% { transform: rotate(0deg); }
        100% { transform: rotate(360deg); }
    }

    @keyframes fadeIn {
        0% { opacity: 0; transform: translateY(10px); }
        100% { opacity: 1; transform: translateY(0); }
    }

    @keyframes slideIn {
        0% { opacity: 0; transform: translateX(-10px); }
        100% { opacity: 1; transform: translateX(0); }
    }

    /* Responsive adjustments */
    @media (max-width: 768px) {
        .custom-header, .results-panel, .custom-footer {
            padding: 16px;
        }

        h2 {
            font-size: 1.5em !important;
        }

        .widget-label {
            font-size: 0.9em !important;
        }
    }

    /* Keyboard focus indicators for accessibility */
    :focus {
        outline: 2px solid var(--primary-color) !important;
        outline-offset: 2px !important;
    }

    /* High contrast mode support */
    @media screen and (forced-colors: active) {
        .jupyter-button, .results-panel, .custom-header, .loading-spinner {
            border: 2px solid CanvasText;
        }

        .key-term {
            border: 1px solid CanvasText;
        }
    }

    /* Removed the bug in bottom right corner by ensuring all containers have proper bounds */
    .output_area img {
        max-width: 100% !important;
        height: auto !important;
        background-color: transparent !important;
    }

    /* Make plots responsive */
    .plot-container {
        width: 100%;
        max-width: 900px;
        margin: 0 auto;
        background-color: transparent !important;
    }
</style>
"""
display(HTML(custom_style))

# Initialize model paths
# MODEL_DIR lives on the mounted Google Drive, so it only resolves after
# drive.mount('/content/drive') has succeeded.
MODEL_DIR = "/content/drive/MyDrive/deployment/pubmedbert"  # Updated path
# Optional JSON file holding the 'id2label' / 'label2id' dictionaries.
LABEL_FILE = os.path.join(MODEL_DIR, "label_mapping.json")

# Load the model and tokenizer
def load_transformer_model(model_dir):
    """
    Load a transformer model and tokenizer from the specified directory

    The label mapping is read from ``label_mapping.json`` inside ``model_dir``
    (the directory that was actually passed in, not a module-level path). When
    that file is absent the mapping stored in the model config is used instead.

    Parameters:
    -----------
    model_dir : str
        Path to the saved model directory

    Returns:
    --------
    tuple
        (model, tokenizer, id2label, label2id). If anything fails while loading,
        the error is printed and (None, None, <demo id2label>, <demo label2id>)
        is returned so the UI can fall back to demo mode.
    """
    try:
        # Load the model
        model = AutoModelForSequenceClassification.from_pretrained(model_dir)
        tokenizer = AutoTokenizer.from_pretrained(model_dir)

        # Load label mapping from the directory we were asked to load
        label_file = os.path.join(model_dir, "label_mapping.json")
        if os.path.exists(label_file):
            with open(label_file, 'r', encoding='utf-8') as f:
                label_mapping = json.load(f)
            # JSON object keys are always strings; predictions index by int
            id2label = {int(k): v for k, v in label_mapping['id2label'].items()}
            # Derive the inverse map when the file does not carry one, rather
            # than discarding a model that loaded successfully
            label2id = label_mapping.get('label2id') or {v: k for k, v in id2label.items()}
        else:
            # Fallback to model's internal mapping
            id2label = model.config.id2label
            label2id = model.config.label2id

        return model, tokenizer, id2label, label2id

    except Exception as e:
        # Deliberate best-effort: any failure switches the notebook to demo mode
        print(f"Error loading model: {str(e)}")
        # Create dummy mappings for demonstration
        id2label = {
            0: 'ALS',
            1: 'Dementia',
            2: 'Obsessive Compulsive Disorder',
            3: "Parkinson's Disease",
            4: 'Schizophrenia'
        }
        label2id = {v: k for k, v in id2label.items()}
        return None, None, id2label, label2id

# Define the prediction function
def predict_condition(description, model, tokenizer, id2label, device="cpu"):
    """
    Classify one clinical-trial description and report per-class probabilities.

    Parameters:
    -----------
    description : str
        Clinical trial description
    model : transformers.PreTrainedModel
        Trained transformer model; it is moved to ``device`` and put in eval mode
    tokenizer : transformers.PreTrainedTokenizer
        Tokenizer for the model; input is truncated/padded to 512 tokens
    id2label : dict
        Mapping from integer class ID to label
    device : str
        Device to run inference on

    Returns:
    --------
    dict
        'prediction' (label), 'confidence' (float) and 'probabilities'
        (label -> float). On any exception the message is printed and a dict
        with 'error' set, 'prediction'/'confidence' None and empty
        'probabilities' is returned instead.
    """
    try:
        # Place the model on the target device in inference mode
        model = model.to(device)
        model.eval()

        # Encode the text as fixed-length tensors
        encoded = tokenizer(
            description,
            max_length=512,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )
        encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

        # Forward pass without gradient tracking
        with torch.no_grad():
            logits = model(**encoded).logits
            probabilities = torch.nn.functional.softmax(logits, dim=1)

        # Winning class for the single input row
        best_id = torch.argmax(probabilities, dim=1).item()
        best_label = id2label[best_id]

        # Label -> probability for every class
        per_class = {}
        for class_id, prob in enumerate(probabilities[0]):
            per_class[id2label[class_id]] = float(prob)

        return {
            'prediction': best_label,
            'confidence': float(probabilities[0][best_id]),
            'probabilities': per_class
        }

    except Exception as e:
        print(f"Error in prediction: {str(e)}")
        failure = {
            'error': str(e),
            'prediction': None,
            'confidence': None,
            'probabilities': {}
        }
        return failure

# Try to load the model
# Prefer the GPU when one is visible to torch; otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# This assignment rebinds the module-level `tokenizer`. When loading fails,
# load_transformer_model returns None for both `model` and `tokenizer`.
model, tokenizer, id2label, label2id = load_transformer_model(MODEL_DIR)

# Status strings interpolated into the header widget.
if model and tokenizer:
    model_message = '<span style="color:#03dac6">✓ PubMedBERT model loaded successfully!</span>'
    model_source = "PubMedBERT transformer model (fine-tuned for medical trial classification)"
else:
    model_message = '<span style="color:#cf6679">⚠ Using demo mode (model not found)</span>'
    model_source = "Demo mode (functionality limited)"

# Header with accessibility improvements
# Static banner: the f-string is evaluated once here, so the model status shown is
# whatever `model_message` / `model_source` held when this cell ran.
header = widgets.HTML(
    f"""
    <div class="custom-header" role="banner" aria-label="Application Header">
        <h2 style="margin-top:0; margin-bottom:10px;">
            <span aria-hidden="true">📊</span> Medical Trial Condition Predictor
        </h2>
        <p style="margin:0;">Analyze clinical trial descriptions to predict associated medical conditions</p>
        <p style="margin-top:10px; font-size:0.9em;">
            <b>Model Status:</b> {model_message}<br>
            <b>Source:</b> {model_source}
        </p>
        <div class="model-info">
            <p style="margin:5px 0;">
                <b>Model Information:</b><br>
                • PubMedBERT: Specialized BERT model pre-trained on PubMed abstracts and full-text<br>
                • Fine-tuned for medical trial classification<br>
                • <span class="tooltip">Optimized for recognizing medical conditions from trial descriptions
                    <span class="tooltiptext">This model has been trained specifically to identify medical conditions based on text patterns in clinical trial descriptions</span>
                </span>
            </p>
        </div>
    </div>
    """
)

# Sample dropdown with improved accessibility
# Each option is a (display label, value) pair; the values are the keys of the
# `samples` dictionary, and '' is the "nothing selected" placeholder.
sample_dropdown = widgets.Dropdown(
    options=[
        ('Select a sample description', ''),
        ('ALS Sample', 'als_sample'),
        ('Dementia Sample', 'dementia_sample'),
        ('OCD Sample', 'ocd_sample'),
        ('Parkinson\'s Sample', 'parkinsons_sample'),
        ('Schizophrenia Sample', 'schizophrenia_sample')
    ],
    value='',
    description='Example Trials:',
    style={'description_width': 'initial'},
    layout=widgets.Layout(margin='8px 0 16px 0', width='100%')
)

# Help text for the dropdown
# The tooltip is revealed by the `.tooltip:hover` CSS rule, i.e. on hover.
sample_help = widgets.HTML(
    """
    <div style="margin: -10px 0 16px 0; font-size: 0.9em; color: rgba(255,255,255,0.7);">
        <span class="tooltip" aria-label="Click for more information">ℹ️
            <span class="tooltiptext">Select from pre-defined examples to see predictions for various conditions</span>
        </span>
    </div>
    """
)

# Create textarea with accessibility attributes
# Free-text input whose `.value` is read by the Predict handler.
description_input = widgets.Textarea(
    value='',
    placeholder='Enter clinical trial description here...',
    description='Description:',
    disabled=False,
    layout=widgets.Layout(width='100%', height='200px'),
)

# Help text for the textarea
description_help = widgets.HTML(
    """
    <div style="margin: -10px 0 16px 0; font-size: 0.9em; color: rgba(255,255,255,0.7);">
        <span class="tooltip" aria-label="Click for more information">ℹ️
            <span class="tooltiptext">Enter a detailed description of a clinical trial. More details will lead to better predictions.</span>
        </span>
        Enter a clinical trial description to predict the associated medical condition
    </div>
    """
)

# Create buttons with improved accessibility
# button_style='primary' / 'warning' produce the `mod-primary` / `mod-warning`
# classes that the custom stylesheet and the keyboard-shortcut script select on.
predict_button = widgets.Button(
    description='Predict',
    button_style='primary',
    tooltip='Analyze and predict condition',
    icon='search',
    layout=widgets.Layout(margin='4px 8px 4px 0')
)

clear_button = widgets.Button(
    description='Clear Input',
    button_style='warning',
    tooltip='Clear input and results',
    icon='trash',
    layout=widgets.Layout(margin='4px 0 4px 8px')
)

# Output area
# All handler output (spinner, errors, results) is rendered inside this widget.
output_area = widgets.Output()

# Sample descriptions
# Keys must match the option values of `sample_dropdown`; the dropdown observer
# copies the selected text into the description textarea.
samples = {
    'als_sample': "This clinical trial evaluates the safety and efficacy of an investigational therapy in people with Amyotrophic Lateral Sclerosis (ALS). ALS is a progressive neurodegenerative disease affecting motor neurons in the brain and spinal cord, leading to muscle weakness and eventual respiratory failure. The study measures changes in the ALSFRS-R score and survival time.",

    'dementia_sample': "This study investigates a new treatment for patients with mild to moderate Alzheimer's disease dementia. The trial evaluates whether the drug can slow cognitive decline and improve daily functioning. Primary outcomes include changes in cognitive performance measured by ADAS-Cog and CDR-SB.",

    'ocd_sample': "This trial examines a combined therapy approach for treatment-resistant Obsessive-Compulsive Disorder (OCD). OCD is characterized by unwanted thoughts and repetitive behaviors that impair functioning. The study measures symptom reduction using the Yale-Brown Obsessive Compulsive Scale.",

    'parkinsons_sample': "This study evaluates a new dopamine agonist for early-stage Parkinson's disease. Parkinson's is characterized by tremor, rigidity, and slowness of movement due to dopamine cell loss. The trial measures changes in the Unified Parkinson's Disease Rating Scale scores.",

    'schizophrenia_sample': "This clinical trial assesses the efficacy of a novel antipsychotic medication for treating schizophrenia symptoms. The study focuses on reducing positive symptoms like hallucinations and delusions, as well as negative symptoms like reduced emotional expression and avolition. Primary outcomes are measured using the PANSS scale over a 12-week treatment period."
}

# Enhanced visualization function for probabilities
def create_probability_plot(probabilities):
    """
    Draw the class probabilities as a horizontal bar chart on a transparent figure

    Parameters:
    -----------
    probabilities : dict
        Dictionary of class labels and their probabilities

    Returns:
    --------
    str
        Base64 encoded PNG image, or None when ``probabilities`` is empty
    """
    if not probabilities:
        return None

    # Rank classes from most to least probable
    ranked = sorted(probabilities.items(), key=lambda pair: pair[1], reverse=True)
    labels = [name for name, _ in ranked]
    values = [prob for _, prob in ranked]

    # Off-screen figure whose own patch is fully transparent
    fig = Figure(figsize=(10, 5), dpi=100)
    fig.patch.set_alpha(0.0)
    ax = fig.add_subplot(111)
    ax.set_facecolor('#21212180')  # Semi-transparent background

    # One viridis shade per bar, sampled from the middle of the colormap
    palette = plt.cm.viridis(np.linspace(0.2, 0.8, len(values)))
    bars = ax.barh(labels, values, color=palette, alpha=0.9)

    # Percentage label to the right of each bar (never left of x=0.05)
    for bar, value in zip(bars, values):
        label_x = max(bar.get_width() + 0.01, 0.05)
        label_y = bar.get_y() + bar.get_height()/2
        ax.text(label_x, label_y, f'{value:.1%}',
                va='center', color='white', fontweight='bold')

    # Titles, axis label and range
    ax.set_title('Prediction Probabilities', color='#bb86fc', fontsize=14, fontweight='bold')
    ax.set_xlabel('Probability', color='white', fontsize=12)
    ax.set_xlim(0, 1.05)

    # Vertical guide lines only
    ax.grid(axis='x', linestyle='--', alpha=0.3, color='#dddddd')

    # Hide the top/right frame, mute the left/bottom one
    for side in ('top', 'right'):
        ax.spines[side].set_visible(False)
    for side in ('left', 'bottom'):
        ax.spines[side].set_color('#555555')

    ax.tick_params(axis='both', colors='white')

    fig.tight_layout()

    # Render to PNG in memory and return it base64-encoded
    canvas = FigureCanvas(fig)
    png_buffer = BytesIO()
    canvas.print_png(png_buffer)
    encoded_png = base64.b64encode(png_buffer.getbuffer()).decode("ascii")

    return encoded_png

def highlight_key_terms_html(text, terms):
    """
    Wrap every occurrence of the given terms in ``<span class="key-term">``.

    All terms are matched in a single regex pass, so markup inserted for one
    term is never re-scanned for another (a term such as "term" or "span" can
    no longer corrupt previously inserted tags). Matching is case-insensitive
    and the original casing of the matched text is kept.

    Parameters:
    -----------
    text : str
        Text to annotate. It is not HTML-escaped here; escape it beforehand
        if it comes from user input.
    terms : iterable of str
        Terms to highlight; empty strings are ignored

    Returns:
    --------
    str
        ``text`` with highlights applied, or ``text`` unchanged when there is
        nothing to highlight
    """
    # An empty term would match between every character, so drop those.
    # Longest first so that a longer term wins over a shorter prefix of it.
    candidates = sorted({term for term in (terms or []) if term},
                        key=lambda term: (-len(term), term))
    if not candidates:
        return text

    pattern = re.compile("|".join(re.escape(term) for term in candidates), re.IGNORECASE)
    return pattern.sub(
        lambda match: f'<span class="key-term">{match.group(0)}</span>',
        text
    )

from collections import Counter
from transformers import AutoTokenizer


# Common English function words longer than three characters. Shorter words are
# already removed by the length filter in identify_key_terms.
_KEY_TERM_STOP_WORDS = frozenset({
    "about", "after", "also", "been", "before", "being", "between", "both",
    "could", "does", "during", "each", "from", "have", "including", "into",
    "like", "more", "most", "only", "other", "over", "should", "some", "such",
    "than", "that", "their", "them", "then", "there", "these", "they", "this",
    "those", "through", "under", "used", "using", "very", "well", "were",
    "what", "when", "where", "whether", "which", "while", "will", "with",
    "would",
})


def identify_key_terms(text, prediction=None):
    """
    Identify the most frequent content words in the input text.

    The text is lower-cased and split into alphanumeric words. Words of three
    characters or fewer and common function words are discarded, and the
    remaining words are ranked by frequency (ties keep first-occurrence order).
    The function no longer depends on a tokenizer: using a tokenizer's whole
    vocabulary as the stop list removed every token, so it always returned [].

    Parameters:
    -----------
    text : str
        Clinical trial description
    prediction : str, optional
        Accepted for interface compatibility; not used

    Returns:
    --------
    list of str
        Up to 8 lower-cased key terms, most frequent first. An empty list is
        returned (and the error printed) if the text cannot be processed.
    """
    try:
        # Runs of letters/digits; "[^\W_]" is "\w" without the underscore
        words = re.findall(r"[^\W_]+", text.lower())

        # Keep informative words only
        filtered_words = [
            word for word in words
            if len(word) > 3 and word not in _KEY_TERM_STOP_WORDS
        ]

        # Rank by frequency and keep the top terms
        word_freq = Counter(filtered_words)
        return [term for term, _ in word_freq.most_common(8)]

    except Exception as e:
        # Best-effort helper: a failure here must not break the prediction flow
        print(f"Error in key term identification: {str(e)}")
        return []


# Define button behaviors
def _simulate_prediction(id2label):
    """Fabricate a high-confidence prediction for demo mode (no model loaded)."""
    time.sleep(1.5)  # Simulate processing time
    classes = list(id2label.values())
    prediction = np.random.choice(classes)
    top_idx = classes.index(prediction)

    # Give the chosen class 0.9-1.0 and split the remainder evenly
    top_prob = 0.9 + np.random.random() * 0.1
    n_others = len(classes) - 1
    probs = np.full(len(classes), (1.0 - top_prob) / n_others if n_others else 0.0)
    probs[top_idx] = top_prob

    return {
        'prediction': str(prediction),
        'confidence': float(top_prob),
        'probabilities': {cls: float(prob) for cls, prob in zip(classes, probs)}
    }


def _confidence_band(confidence):
    """Map a confidence score to its (description, badge colour) pair."""
    if confidence > 0.95:
        return "Very High Confidence", "#03dac6"
    if confidence > 0.85:
        return "High Confidence", "#2ECC71"
    if confidence > 0.70:
        return "Moderate Confidence", "#F39C12"
    return "Low Confidence", "#E67E22"


def _error_html(message):
    """Return the alert markup for an error message, HTML-escaping the message."""
    from html import escape

    return f'''
        <div class="error-message" role="alert">
            <span aria-hidden="true">⚠️</span> {escape(str(message))}
        </div>
    '''


def _render_prediction_html(result, prob_image, key_terms, highlighted_text):
    """Build the results panel markup with every container opened and closed once."""
    from html import escape

    conf_text, conf_color = _confidence_band(result['confidence'])
    confidence_pct = f"{result['confidence']:.1%}"
    # Labels come from a JSON file or the model config; escape them for HTML
    label = escape(str(result['prediction']))

    parts = [f'''
    <div class="results-panel" role="region" aria-label="Prediction Results">
        <h3 style="margin-top:0; display: flex; justify-content: space-between; align-items: center;">
            <span>Prediction Results</span>
            <span style="font-size: 0.8em; background-color: {conf_color}; color: #000; padding: 4px 8px;
                    border-radius: 16px; opacity: 0.9;" class="tooltip">
                {conf_text}
                <span class="tooltiptext">Confidence level: {confidence_pct}</span>
            </span>
        </h3>

        <div style="display: flex; flex-wrap: wrap; gap: 20px; margin-bottom: 20px;">
            <!-- Prediction Section -->
            <div style="flex: 1; min-width: 250px;">
                <p><b>Predicted Condition:</b>
                    <span class="prediction-badge" aria-label="Predicted condition: {label}">
                        {label}
                    </span>
                </p>

                <p><b>Confidence Score:</b>
                    <span style="font-weight: 600; color: {conf_color};">{confidence_pct}</span>
                </p>

                <p><b>Model Used:</b>
                    <span class="tooltip">PubMedBERT
                        <span class="tooltiptext">Specialized BERT model pre-trained on biomedical literature</span>
                    </span>
                </p>
            </div>
    ''']

    # Second column: only rendered when key terms were actually found
    if key_terms:
        term_list = ", ".join(escape(str(term)) for term in key_terms)
        parts.append(f'''
            <!-- Key Terms Section -->
            <div style="flex: 2; min-width: 300px;">
                <p><b>Key Terms:</b> {term_list}</p>
                <div class="viz-container">
                    <p style="margin: 0;">{highlighted_text}</p>
                </div>
            </div>
        ''')

    parts.append('''
        </div>

        <hr style="border-color:#424242; margin:15px 0; opacity: 0.5;">

        <!-- Probability Visualization -->
        <div class="plot-container" aria-label="Probability distribution chart">
    ''')

    if prob_image:
        parts.append(f'''
            <img src="data:image/png;base64,{prob_image}"
                style="width:100%; max-width:800px; background-color: transparent;"
                alt="Bar chart showing prediction probabilities for each class" />
        ''')

    parts.append('''
            <div style="margin-top: 20px; font-size: 0.9em; color: #9e9e9e;">
                <span class="tooltip">
                    Note: This prediction is based on text pattern recognition and should not replace professional medical diagnosis.
                    <span class="tooltiptext">AI predictions should be verified by healthcare professionals</span>
                </span>
            </div>
        </div>
    </div>
    ''')

    return ''.join(parts)


def on_predict_button_clicked(b):
    """
    Click handler for the Predict button.

    Reads the description textarea, shows a loading indicator, obtains a
    prediction (from the loaded model, or a simulated one in demo mode) and
    renders the results panel inside ``output_area``. Empty input and any
    failure are reported as an alert in the same output area.

    Parameters:
    -----------
    b : ipywidgets.Button
        The clicked button (unused)
    """
    from html import escape

    with output_area:
        clear_output()
        text = description_input.value

        if not text.strip():
            display(HTML('''
                <div class="error-message" role="alert">
                    <span aria-hidden="true">⚠️</span> Please enter a clinical trial description
                </div>
            '''))
            return

        display(HTML('''
            <div class="loading-container" role="status">
                <div class="loading-spinner" aria-hidden="true"></div>
                <p style="text-align:center; color:#e0e0e0;">
                    Analyzing trial description with PubMedBERT...
                </p>
                <span class="sr-only">Processing the request. Please wait...</span>
            </div>
        '''))

        try:
            if model and tokenizer:
                result = predict_condition(text, model, tokenizer, id2label, device)
            else:
                # Demo mode - simulate a prediction
                result = _simulate_prediction(id2label)

            # Replace the spinner with the outcome
            clear_output()

            if result['prediction'] is None:
                display(HTML(_error_html(f'Error: {result.get("error", "Unknown error")}')))
                return

            # Key terms are highlighted in an escaped copy of the input so that
            # user text is never interpreted as markup. quote=False keeps the
            # escaped text free of long entity names a term could match inside.
            key_terms = identify_key_terms(text, result['prediction'])
            highlighted_text = (
                highlight_key_terms_html(escape(text, quote=False), key_terms)
                if key_terms else ''
            )

            # Create the probability visualization
            prob_image = create_probability_plot(result['probabilities'])

            display(HTML(_render_prediction_html(result, prob_image, key_terms, highlighted_text)))

        except Exception as e:
            clear_output()
            display(HTML(_error_html(f'Error making prediction: {e}')))

def on_clear_button_clicked(b):
    """Click handler for the Clear button: empty the textarea and show the usage hint again."""
    usage_hint = HTML('''
            <div style="text-align: center; color: #9e9e9e; margin: 20px 0;">
                Enter a clinical trial description and click "Predict Condition" to analyze
            </div>
        ''')

    # Reset the input first, then swap any previous results for the hint
    description_input.value = ''
    with output_area:
        clear_output()
        display(usage_hint)

def on_sample_selected(change):
    """Dropdown observer: copy the chosen sample into the textarea, then reset the dropdown."""
    selected_key = change['new']
    # Ignore the empty placeholder (including the reset performed below) and unknown keys
    if not selected_key or selected_key not in samples:
        return

    description_input.value = samples[selected_key]
    sample_dropdown.value = ''  # Reset dropdown after selection

# Connect callbacks to widgets
# NOTE(review): re-running this cell creates new widget objects above, so handlers
# are registered on the fresh instances rather than stacked on old ones.
predict_button.on_click(on_predict_button_clicked)
clear_button.on_click(on_clear_button_clicked)
# names='value' limits the observer to changes of the selected value.
sample_dropdown.observe(on_sample_selected, names='value')

# Create a footer with attribution
# Static markup; the tooltip texts are revealed on hover by the `.tooltip` CSS rules.
footer = widgets.HTML(
    """
    <div class="custom-footer" role="contentinfo">
        <p style="margin-bottom: 5px;">
            <span aria-hidden="true">🔬</span> Medical Trial Classification System | Transformer-powered Condition Prediction
        </p>
        <p style="margin: 0; font-weight: 500; color: #bb86fc;">
            Created by Harry Patria © 2025
        </p>
        <div style="font-size: 0.85em; margin-top: 8px;">
            <span class="tooltip">Accessibility
                <span class="tooltiptext">This application is designed to be screen reader friendly and keyboard navigable</span>
            </span> •
            <span class="tooltip">Privacy
                <span class="tooltiptext">All predictions are processed locally and no data is stored</span>
            </span> •
            <span class="tooltip">Disclaimer
                <span class="tooltiptext">For research purposes only. Not for clinical diagnosis.</span>
            </span>
        </div>
    </div>
    """
)

# Initial instruction in output area
# Same hint that the Clear handler restores.
# NOTE(review): the hint says "Predict Condition" while the button is labelled 'Predict'.
with output_area:
    display(HTML('''
        <div style="text-align: center; color: #9e9e9e; margin: 20px 0;">
            Enter a clinical trial description and click "Predict Condition" to analyze
        </div>
    '''))

# Layout widgets for display with improved organization
# Buttons sit side by side, left-aligned.
button_box = widgets.HBox([predict_button, clear_button])
button_box.layout.justify_content = 'flex-start'
button_box.layout.margin = '8px 0 16px 0'

# Create description section
# Textarea followed by its help text.
description_section = widgets.VBox([
    description_input,
    description_help
])

# Assemble the main interface with better spacing
# Children are stacked top to bottom in list order; the container is centred
# (margin '0 auto') and capped at 960px wide.
main_box = widgets.VBox([
    header,
    sample_dropdown,
    sample_help,
    description_section,
    button_box,
    output_area,
    footer
], layout=widgets.Layout(
    width='100%',
    max_width='960px',
    margin='0 auto'
))

# Display the interactive interface
display(main_box)

# Add keyboard shortcut help (accessible via keyboard navigation)
# The panel starts hidden (display: none) and is toggled by the '?' key handler in
# the embedded script. Alt+P / Alt+C click the first `.jupyter-button.mod-primary`
# / `.jupyter-button.mod-warning` element found in the document.
# The listener is attached to `document` and does not check the event target, so
# typing '?' inside the description textarea also toggles the panel.
# NOTE(review): this markup is passed to widgets.HTML; confirm that the target
# frontend actually executes the embedded <script>. If it does not, the panel can
# never be shown and the shortcuts are inert.
# NOTE(review): document.querySelector / getElementById may return null if the
# elements live in a different document or frame — verify before relying on this.
keyboard_help = widgets.HTML(
    """
    <div style="position: fixed; bottom: 20px; right: 20px;
            background-color: rgba(30, 30, 30, 0.8); padding: 12px;
            border-radius: 8px; border: 1px solid #444;
            font-size: 0.85em; z-index: 1000; display: none;"
            id="keyboard-shortcuts">
        <h4 style="margin-top: 0; color: #bb86fc;">Keyboard Shortcuts</h4>
        <ul style="padding-left: 20px; margin-bottom: 5px;">
            <li><kbd>Tab</kbd> - Navigate between elements</li>
            <li><kbd>Enter</kbd> - Activate buttons or controls</li>
            <li><kbd>Alt+P</kbd> - Predict (when in text area)</li>
            <li><kbd>Alt+C</kbd> - Clear (when in text area)</li>
            <li><kbd>?</kbd> - Toggle this help</li>
        </ul>
    </div>

    <script>
        // Add keyboard shortcut functionality
        document.addEventListener('keydown', function(e) {
            // Show/hide keyboard help with '?'
            if (e.key === '?') {
                const helpPanel = document.getElementById('keyboard-shortcuts');
                helpPanel.style.display = helpPanel.style.display === 'none' ? 'block' : 'none';
            }

            // Alt+P to predict
            if (e.altKey && e.key === 'p') {
                document.querySelector('.jupyter-button.mod-primary').click();
                e.preventDefault();
            }

            // Alt+C to clear
            if (e.altKey && e.key === 'c') {
                document.querySelector('.jupyter-button.mod-warning').click();
                e.preventDefault();
            }
        });
    </script>
    """
)

display(keyboard_help)

# Add notification for screen reader users
# Visually hidden element (1px box, clipped) marked aria-live="polite". The inline
# style repeats the declarations of the `.sr-only` class from the stylesheet.
# The text is static: nothing in this notebook updates it after display.
screen_reader_notification = widgets.HTML(
    """
    <div aria-live="polite" id="sr-announcement"
         style="position: absolute; width: 1px; height: 1px;
                padding: 0; margin: -1px; overflow: hidden;
                clip: rect(0, 0, 0, 0); white-space: nowrap;
                border-width: 0;">
        Medical Trial Condition Predictor application loaded.
        Use tab to navigate through the form elements.
    </div>
    """
)

display(screen_reader_notification)

VBox(children=(HTML(value='\n    <div class="custom-header" role="banner" aria-label="Application Header">\n  …

HTML(value='\n    <div style="position: fixed; bottom: 20px; right: 20px; \n            background-color: rgba…

HTML(value='\n    <div aria-live="polite" id="sr-announcement" \n         style="position: absolute; width: 1p…