In [1]:
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.applications.densenet import DenseNet201
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import (Dense, GlobalAveragePooling2D, Conv2D,
                                   Conv2DTranspose, Reshape, Add, Multiply,
                                   MultiHeadAttention, LayerNormalization, Dropout,
                                   Input, Concatenate)
from tensorflow.keras.metrics import Recall, Precision, AUC
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from zipfile import ZipFile
import os

# Install required packages for BioBERT
try:
    import transformers
    from transformers import AutoTokenizer, TFAutoModel
    print("Transformers library already installed")
except ImportError:
    print("Installing transformers library...")
    os.system("pip install transformers")
    import transformers
    from transformers import AutoTokenizer, TFAutoModel

# Mount Google Drive (if using Colab)
try:
    from google.colab import drive
    drive.mount('/content/drive')
    print("Google Drive mounted successfully")
except ImportError:
    print("Not running in Colab, skipping drive mount")

Transformers library already installed
Mounted at /content/drive
Google Drive mounted successfully


In [2]:
# BioBERT Disease Embedding Class
class BioBERTDiseaseEmbedder:
    """Class to generate disease embeddings using BioBERT"""

    def __init__(self, model_name='dmis-lab/biobert-base-cased-v1.1', embedding_dim=512):
        self.model_name = model_name
        self.embedding_dim = embedding_dim
        self.tokenizer = None
        self.model = None
        self.disease_embeddings = None

    def load_model(self):
        """Load BioBERT tokenizer and model"""
        print("Loading BioBERT model...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
            # Try loading with PyTorch conversion first
            self.model = TFAutoModel.from_pretrained(self.model_name, from_pt=True)
            print("BioBERT model loaded successfully (converted from PyTorch)")
        except Exception as e:
            print(f"Error loading BioBERT model: {e}")
            print("Falling back to alternative biomedical model...")
            # Fallback to a model with native TensorFlow support
            try:
                self.model_name = 'bert-base-uncased'  # Fallback option
                self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
                self.model = TFAutoModel.from_pretrained(self.model_name)
                print("Fallback model loaded successfully")
            except Exception as e2:
                print(f"Fallback model also failed: {e2}")
                raise e2

    def generate_disease_embeddings(self, disease_labels):
        """Generate embeddings for disease labels"""
        if self.tokenizer is None or self.model is None:
            self.load_model()

        embeddings = []

        print("Generating disease embeddings...")
        for disease in disease_labels:
            # Tokenize the disease name
            inputs = self.tokenizer(
                disease,
                return_tensors='tf',
                padding=True,
                truncation=True,
                max_length=64
            )

            # Get BioBERT embeddings
            outputs = self.model(**inputs)

            # Use [CLS] token embedding (first token)
            cls_embedding = outputs.last_hidden_state[:, 0, :]  # Shape: (1, 768)

            embeddings.append(cls_embedding.numpy())

        # Stack all embeddings
        disease_embeddings = np.vstack(embeddings)  # Shape: (num_diseases, 768)

        # Project to desired embedding dimension if needed
        if disease_embeddings.shape[1] != self.embedding_dim:
            projection_matrix = np.random.normal(
                0, 0.02, (disease_embeddings.shape[1], self.embedding_dim)
            )
            disease_embeddings = disease_embeddings @ projection_matrix

        self.disease_embeddings = disease_embeddings
        print(f"Disease embeddings generated: {disease_embeddings.shape}")

        return disease_embeddings

In [7]:
# Data extraction
def extract_data():
    """Extract dataset from zip file"""
    try:
        with ZipFile("/content/drive/MyDrive/mured.zip", 'r') as zip_file:
            zip_file.extractall()
            print("Data extraction completed")
    except FileNotFoundError:
        print("Warning: Zip file not found. Please ensure data is available.")

# Data loading and preprocessing
def load_and_prepare_data():
    """Load and prepare training and test data"""
    try:
        # Load data
        train_data = pd.read_csv('/content/drive/MyDrive/train_data_modified.csv')
        test_data = pd.read_csv('/content/drive/MyDrive/test_data_modified.csv')

        print(train_data.head())
        print(test_data.head())

        # Sample data for training (adjust as needed)
        train_data = train_data.sample(frac=1, random_state=42)[:1600]
        test_data = test_data[:320]

        print(f"Training data shape: {train_data.shape}")
        print(f"Test data shape: {test_data.shape}")

        return train_data, test_data
    except FileNotFoundError:
        print("Error: CSV files not found. Please check file paths.")
        return None, None

# Define disease labels
DISEASE_LABELS = ['DR', 'NORMAL', 'MH', 'ODC', 'TSLN', 'ARMD', 'DN', 'MYA',
                  'BRVO', 'ODP', 'CRVO', 'CNV', 'RS', 'ODE', 'LS', 'CSR',
                  'HTR', 'ASR', 'CRS', 'OTHER']

DISEASE_LABELS_FULL = ['DIABETIC RETINOPATHY', 'NORMAL', 'MEDIA HAZE',
                       'OPTIC DISC COLOBOMA', 'TESSELLATION',
                       'AGE RELATED MACULAR DEGENERATION', 'DRUSEN', 'MYOPIA',
                       'BRANCH RETINAL VEIN OCCLUSION', 'OPTIC DISC PALLOR',
                       'CENTRAL RETINAL VEIN OCCLUSION', 'CHOROIDAL NEOVASCULARIZATION',
                       'RETINITIS', 'OPTIC DISC EDEMA', 'LASER SCARS',
                       'CENTRAL SEROUS RETINOPATHY', 'HYPERTENSIVE RETINOPATHY',
                       'ARTIFICIAL SILICON RETINA', 'CHORIORETINITIS', 'OTHER']

In [None]:
# Data generators
def create_data_generators(train_data, test_data, batch_size=16, img_size=(320, 320)):
    """Create data generators for training, validation, and testing"""

    # Training data generator with augmentation
    train_datagen = ImageDataGenerator(
        rescale=1./255,
        validation_split=0.2,
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest'
    )

    # Test data generator (no augmentation)
    test_datagen = ImageDataGenerator(rescale=1./255)

    train_generator = train_datagen.flow_from_dataframe(
        dataframe=train_data,
        directory="/content/images/images",
        x_col="ID_2",
        y_col=DISEASE_LABELS,
        class_mode='raw',
        batch_size=batch_size,
        target_size=img_size,
        subset='training'
    )

    val_generator = train_datagen.flow_from_dataframe(
        dataframe=train_data,
        directory="/content/images/images",
        x_col="ID_2",
        y_col=DISEASE_LABELS,
        class_mode='raw',
        batch_size=batch_size,
        target_size=img_size,
        subset='validation'
    )

    test_generator = test_datagen.flow_from_dataframe(
        dataframe=test_data,
        directory="/content/images/images",
        x_col="ID_2",
        y_col=DISEASE_LABELS,
        class_mode='raw',
        batch_size=batch_size,
        target_size=img_size,
        shuffle=False
    )

    return train_generator, val_generator, test_generator

In [None]:
# Custom layers
class FullyConnectedLayer(tf.keras.layers.Layer):
    """Custom fully connected layer for transformer-like architecture"""

    def __init__(self, embedding_dim, fully_connected_dim, dropout_rate=0.1, **kwargs):
        super(FullyConnectedLayer, self).__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.fully_connected_dim = fully_connected_dim
        self.dropout_rate = dropout_rate

        self.dense1 = Dense(fully_connected_dim, activation='relu')
        self.dense2 = Dense(embedding_dim)
        self.dropout = Dropout(dropout_rate)

    def call(self, x, training=False):
        x = self.dense1(x)
        x = self.dropout(x, training=training)
        return self.dense2(x)

    def get_config(self):
        config = super().get_config()
        config.update({
            "embedding_dim": self.embedding_dim,
            "fully_connected_dim": self.fully_connected_dim,
            "dropout_rate": self.dropout_rate
        })
        return config

class EncoderLayer(tf.keras.layers.Layer):
    """Transformer encoder layer with multi-head attention"""

    def __init__(self, embedding_dim, num_heads, fully_connected_dim,
                 dropout_rate=0.1, **kwargs):
        super(EncoderLayer, self).__init__(**kwargs)

        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.fully_connected_dim = fully_connected_dim
        self.dropout_rate = dropout_rate

        self.mha = MultiHeadAttention(
            num_heads=num_heads,
            key_dim=embedding_dim,
            dropout=dropout_rate
        )

        self.ffn = FullyConnectedLayer(embedding_dim, fully_connected_dim, dropout_rate)
        self.layernorm1 = LayerNormalization(epsilon=1e-6)
        self.layernorm2 = LayerNormalization(epsilon=1e-6)
        self.dropout1 = Dropout(dropout_rate)

    def call(self, inputs, training=False):
        # Multi-head attention
        attn_output = self.mha(inputs, inputs, inputs, training=training)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)

        # Feed forward network
        ffn_output = self.ffn(out1, training=training)
        encoder_output = self.layernorm2(out1 + ffn_output)

        return encoder_output

    def get_config(self):
        config = super().get_config()
        config.update({
            "embedding_dim": self.embedding_dim,
            "num_heads": self.num_heads,
            "fully_connected_dim": self.fully_connected_dim,
            "dropout_rate": self.dropout_rate
        })
        return config


In [None]:
class GlobalMeanPoolingLayer(tf.keras.layers.Layer):
    """Custom layer for global mean pooling along sequence dimension"""

    def __init__(self, **kwargs):
        super(GlobalMeanPoolingLayer, self).__init__(**kwargs)

    def call(self, inputs):
        return tf.reduce_mean(inputs, axis=1)

    def get_config(self):
        return super().get_config()

class DiseaseEmbeddingExpansionLayer(tf.keras.layers.Layer):
    """Layer to expand disease embeddings based on batch size"""

    def __init__(self, num_classes, **kwargs):
        super(DiseaseEmbeddingExpansionLayer, self).__init__(**kwargs)
        self.num_classes = num_classes

    def call(self, inputs):
        # inputs is a tuple of (batch_reference, disease_embeddings)
        batch_reference, disease_embeddings = inputs
        batch_size = tf.shape(batch_reference)[0]

        # Expand disease embeddings to match batch size
        expanded = tf.expand_dims(disease_embeddings, 0)
        tiled = tf.tile(expanded, [batch_size, 1, 1])
        return tiled

    def get_config(self):
        config = super().get_config()
        config.update({"num_classes": self.num_classes})
        return config
    """Layer to provide BioBERT disease embeddings to the model"""

class BioBERTDiseaseEmbeddingLayer(tf.keras.layers.Layer):
    """Layer to provide BioBERT disease embeddings to the model"""

    def __init__(self, disease_embeddings, embedding_dim=512, **kwargs):
        super(BioBERTDiseaseEmbeddingLayer, self).__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_classes = disease_embeddings.shape[0]

        # Store disease embeddings as a weight (trainable parameter)
        self.disease_embeddings_weight = None
        self.disease_embeddings_np = disease_embeddings.astype(np.float32)

    def build(self, input_shape):
        # Create the embeddings as a trainable weight
        self.disease_embeddings_weight = self.add_weight(
            name='disease_embeddings',
            shape=(self.num_classes, self.embedding_dim),
            initializer='zeros',
            trainable=False  # Keep as non-trainable since they're pre-computed
        )
        # Initialize with BioBERT embeddings
        self.disease_embeddings_weight.assign(self.disease_embeddings_np)
        super().build(input_shape)

    def call(self, inputs):
        # Get batch size from input tensor
        batch_size = tf.shape(inputs)[0]

        # Expand disease embeddings to match batch size
        # Shape: (num_classes, embedding_dim) -> (batch_size, num_classes, embedding_dim)
        expanded_embeddings = tf.expand_dims(self.disease_embeddings_weight, 0)
        tiled_embeddings = tf.tile(expanded_embeddings, [batch_size, 1, 1])

        return tiled_embeddings

    def get_config(self):
        config = super().get_config()
        config.update({
            "embedding_dim": self.embedding_dim,
            "num_classes": self.num_classes
        })
        return config

In [None]:
# Model building functions
def create_base_model(input_shape=(320, 320, 3)):
    """Create base DenseNet201 model"""
    base_model = DenseNet201(
        include_top=False,
        weights='imagenet',
        input_shape=input_shape
    )
    return base_model

def create_multi_scale_features(base_model):
    """Create multi-scale feature extraction"""

    # Get features from different layers
    high_level_features = base_model.output  # Shape: (batch, 10, 10, 1920)
    low_level_output = base_model.layers[-228].output  # Shape: (batch, 20, 20, 1792)

    # Create models for different feature levels
    low_level_model = Model(inputs=base_model.input, outputs=low_level_output)

    # Multi-Scale Feature Module (MSFM)
    f_h = Conv2D(512, (1, 1), activation='relu', name='high_level_conv')(high_level_features)
    f_l = Conv2D(512, (1, 1), activation='relu', name='low_level_conv')(low_level_output)

    # Upsample high-level features
    f_up = Conv2DTranspose(512, kernel_size=(4, 4), strides=(2, 2),
                          padding='same', activation='relu', name='upsample')(f_h)

    # Combine features
    f_combined = Add(name='feature_add')([f_up, f_l])
    f_refined = Conv2D(512, 3, padding='same', activation='relu', name='refined_conv')(f_combined)

    # Channel Attention Module (CAM)
    f_gap = GlobalAveragePooling2D()(f_refined)
    f_gap_reshaped = Reshape((1, 1, 512))(f_gap)

    f_attention = Conv2D(512, (1, 1), activation='relu')(f_gap_reshaped)
    f_attention = Conv2D(512, (1, 1), activation='sigmoid')(f_attention)

    f_attended = Multiply()([f_refined, f_attention])
    f_final = Add()([f_refined, f_attended])

    return f_final, f_h

In [None]:
def create_biobert_enhanced_model(input_shape=(320, 320, 3), num_classes=20, disease_embeddings=None):
    """Create the complete model with BioBERT disease embeddings"""

    # Input layer
    image_input = Input(shape=input_shape, name='image_input')

    # Base DenseNet201 for visual features
    base_model = DenseNet201(
        include_top=False,
        weights='imagenet',
        input_shape=input_shape
    )

    # Get visual features
    visual_features = base_model(image_input)

    # Multi-scale feature processing
    f_h = Conv2D(512, (1, 1), activation='relu', name='high_level_conv')(visual_features)

    # Try to get low-level features (with error handling)
    try:
        low_level_output = base_model.layers[-228].output
        low_level_model = Model(inputs=base_model.input, outputs=low_level_output)
        low_level_features = low_level_model(image_input)
        f_l = Conv2D(512, (1, 1), activation='relu', name='low_level_conv')(low_level_features)

        # Upsample high-level features
        f_up = Conv2DTranspose(512, kernel_size=(4, 4), strides=(2, 2),
                              padding='same', activation='relu', name='upsample')(f_h)

        # Combine features
        f_combined = Add(name='feature_add')([f_up, f_l])
        f_refined = Conv2D(512, 3, padding='same', activation='relu', name='refined_conv')(f_combined)
    except:
        # Fallback: just use high-level features
        print("Warning: Could not access low-level features, using high-level only")
        f_refined = Conv2D(512, 3, padding='same', activation='relu', name='refined_conv')(f_h)

    # Global pooling to get visual feature vector
    visual_pooled = GlobalAveragePooling2D(name='visual_gap')(f_refined)

    # Reshape visual features for transformer: (batch_size, 1, 512)
    visual_reshaped = Reshape((1, 512), name='visual_reshape')(visual_pooled)

    # BioBERT Disease Embeddings
    if disease_embeddings is not None:
        # Create a constant layer for disease embeddings
        disease_embeddings_const = tf.constant(disease_embeddings.astype(np.float32))

        # Create a layer to expand disease embeddings based on batch size
        class DiseaseEmbeddingLayer(tf.keras.layers.Layer):
            def __init__(self, embeddings, **kwargs):
                super().__init__(**kwargs)
                self.embeddings = embeddings

            def call(self, inputs):
                batch_size = tf.shape(inputs)[0]
                expanded = tf.expand_dims(self.embeddings, 0)
                return tf.tile(expanded, [batch_size, 1, 1])

            def get_config(self):
                config = super().get_config()
                config.update({
                    "embeddings": self.embeddings.numpy().tolist()
                })
                return config

            @classmethod
            def from_config(cls, config):
                embeddings = tf.constant(config.pop("embeddings"))
                return cls(embeddings, **config)

        disease_embedding_layer = DiseaseEmbeddingLayer(disease_embeddings_const, name='disease_embeddings')
        disease_features = disease_embedding_layer(image_input)
    else:
        # Fallback to learnable embeddings
        embedding_layer = tf.keras.layers.Embedding(num_classes, 512, name='learnable_disease_embeddings')
        indices = tf.range(num_classes)

        # Create a layer to handle the embedding expansion
        class DiseaseEmbeddingLayer(tf.keras.layers.Layer):
            def __init__(self, embeddings, **kwargs):
                super().__init__(**kwargs)
                self.embeddings = embeddings

            def call(self, inputs):
                batch_size = tf.shape(inputs)[0]
                expanded = tf.expand_dims(self.embeddings, 0)
                return tf.tile(expanded, [batch_size, 1, 1])

            def get_config(self):
                config = super().get_config()
                config.update({
                    "embeddings": self.embeddings.numpy().tolist()
                })
                return config

            @classmethod
            def from_config(cls, config):
                embeddings = tf.constant(config.pop("embeddings"))
                return cls(embeddings, **config)

        learnable_layer = LearnableEmbeddingLayer(embedding_layer, indices, name='learnable_expansion')
        disease_features = learnable_layer(image_input)

    # Concatenate visual and disease features
    combined_features = Concatenate(axis=1, name='feature_concat')([visual_reshaped, disease_features])

    # Transformer encoder layers
    transformer1 = EncoderLayer(
        embedding_dim=512,
        num_heads=8,
        fully_connected_dim=2048,
        dropout_rate=0.1,
        name='transformer_encoder_1'
    )
    encoded_features1 = transformer1(combined_features)

    # Second transformer layer
    transformer2 = EncoderLayer(
        embedding_dim=512,
        num_heads=8,
        fully_connected_dim=2048,
        dropout_rate=0.1,
        name='transformer_encoder_2'
    )
    encoded_features2 = transformer2(encoded_features1)

    # Global pooling of all tokens using custom layer
    global_pool = GlobalMeanPoolingLayer(name='global_mean_pool')
    final_features = global_pool(encoded_features2)

    # Classification head
    x = Dense(1024, activation='relu', name='classifier_dense1')(final_features)
    x = Dropout(0.5, name='classifier_dropout1')(x)
    x = Dense(512, activation='relu', name='classifier_dense2')(x)
    x = Dropout(0.3, name='classifier_dropout2')(x)
    x = Dense(256, activation='relu', name='classifier_dense3')(x)
    x = Dropout(0.2, name='classifier_dropout3')(x)

    # Final predictions
    predictions = Dense(num_classes, activation='sigmoid', name='predictions')(x)

    # Create final model
    model = Model(inputs=image_input, outputs=predictions, name='BioBERT_Medical_Classifier')

    return model

In [None]:
def create_complete_model(input_shape=(320, 320, 3), num_classes=20, disease_embeddings=None):
    """Create the complete model with BioBERT integration"""
    return create_biobert_enhanced_model(input_shape, num_classes, disease_embeddings)

def compile_model(model, learning_rate=0.001):
    """Compile the model with appropriate optimizer and metrics"""

    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=[
            'accuracy',
            Precision(name='precision'),
            Recall(name='recall'),
            AUC(name='auc')
        ]
    )

    return model

In [None]:
def create_callbacks(model_save_path='best_biobert_model.h5'):
    """Create training callbacks"""

    callbacks = [
        EarlyStopping(
            monitor='val_loss',
            patience=7,
            restore_best_weights=True,
            verbose=1
        ),
        ModelCheckpoint(
            model_save_path,
            monitor='val_auc',
            save_best_only=True,
            mode='max',
            verbose=1
        )
    ]

    return callbacks

In [None]:
def train_model(model, train_gen, val_gen, epochs=15, callbacks=None):
    """Train the model"""

    history = model.fit(
        train_gen,
        steps_per_epoch=len(train_gen),
        epochs=epochs,
        validation_data=val_gen,
        validation_steps=len(val_gen),
        callbacks=callbacks,
        verbose=1
    )

    return history

In [None]:
def evaluate_model(model, test_gen):
    """Evaluate the model on test data"""

    test_loss, test_accuracy, test_precision, test_recall, test_auc = model.evaluate(
        test_gen,
        steps=len(test_gen),
        verbose=1
    )

    print(f"\nTest Results:")
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")
    print(f"Test Precision: {test_precision:.4f}")
    print(f"Test Recall: {test_recall:.4f}")
    print(f"Test AUC: {test_auc:.4f}")

    return {
        'test_loss': test_loss,
        'test_accuracy': test_accuracy,
        'test_precision': test_precision,
        'test_recall': test_recall,
        'test_auc': test_auc
    }

In [None]:
# Main execution function
def main():
    """Main function to run the complete pipeline with BioBERT"""

    print("Starting Medical Image Classification Pipeline with BioBERT...")

    # Extract data
    extract_data()

    # Load and prepare data
    train_data, test_data = load_and_prepare_data()
    if train_data is None or test_data is None:
        print("Failed to load data. Exiting.")
        return

    # Generate BioBERT disease embeddings
    print("Generating BioBERT disease embeddings...")
    biobert_embedder = BioBERTDiseaseEmbedder(embedding_dim=512)
    disease_embeddings = biobert_embedder.generate_disease_embeddings(DISEASE_LABELS_FULL)

    # Create data generators
    train_gen, val_gen, test_gen = create_data_generators(train_data, test_data)

    print("Data generators created successfully")
    print(f"Training samples: {train_gen.n}")
    print(f"Validation samples: {val_gen.n}")
    print(f"Test samples: {test_gen.n}")

    # Create and compile model with BioBERT embeddings
    print("Creating BioBERT-enhanced model...")
    model = create_complete_model(
        input_shape=(320, 320, 3),
        num_classes=20,
        disease_embeddings=disease_embeddings
    )
    model = compile_model(model, learning_rate=0.0001)  # Lower learning rate for stability

    print("Model created and compiled successfully")
    print(f"Model parameters: {model.count_params():,}")

    # Print model summary
    model.summary()

    # Create callbacks
    callbacks = create_callbacks('best_biobert_medical_model.h5')

    # Train model
    print("Starting training...")
    history = train_model(
        model,
        train_gen,
        val_gen,
        epochs=30,
        callbacks=callbacks
    )

    # Evaluate model
    print("Evaluating model...")
    test_results = evaluate_model(model, test_gen)

    print("Training completed successfully!")
    print("\nArchitecture Summary:")
    print("1. Visual Features: DenseNet201 + Multi-scale features → 512-dim vectors")
    print("2. BioBERT Disease Embeddings → 512-dim vectors")
    print("3. Combined features fed to Transformer encoders")
    print("4. Final classification predictions")

    return model, history, test_results, disease_embeddings

# Run the pipeline
if __name__ == "__main__":
    model, history, results, embeddings = main()

Transformers library already installed
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Google Drive mounted successfully
Starting Medical Image Classification Pipeline with BioBERT...
Data extraction completed
Training data shape: (1600, 23)
Test data shape: (320, 23)
Generating BioBERT disease embeddings...
Loading BioBERT model...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceCl

BioBERT model loaded successfully (converted from PyTorch)
Generating disease embeddings...
Disease embeddings generated: (20, 512)
Found 1280 validated image filenames.
Found 320 validated image filenames.
Found 320 validated image filenames.
Data generators created successfully
Training samples: 1280
Validation samples: 320
Test samples: 320
Creating BioBERT-enhanced model...
Model created and compiled successfully
Model parameters: 48,512,340


Starting training...


  self._warn_if_super_not_called()


Epoch 1/30
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.1032 - auc: 0.6014 - loss: 0.3405 - precision: 0.0985 - recall: 0.1068

  self._warn_if_super_not_called()



Epoch 1: val_auc improved from -inf to 0.73143, saving model to best_biobert_medical_model.h5




[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m826s[0m 3s/step - accuracy: 0.1036 - auc: 0.6018 - loss: 0.3397 - precision: 0.0988 - recall: 0.1062 - val_accuracy: 0.2125 - val_auc: 0.7314 - val_loss: 0.2097 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 2/30
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.1448 - auc: 0.6640 - loss: 0.2435 - precision: 0.1326 - recall: 0.0215
Epoch 2: val_auc improved from 0.73143 to 0.74545, saving model to best_biobert_medical_model.h5




[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m210s[0m 2s/step - accuracy: 0.1448 - auc: 0.6641 - loss: 0.2435 - precision: 0.1330 - recall: 0.0215 - val_accuracy: 0.2375 - val_auc: 0.7454 - val_loss: 0.2084 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 3/30
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.1541 - auc: 0.6787 - loss: 0.2385 - precision: 0.2221 - recall: 0.0249
Epoch 3: val_auc did not improve from 0.74545
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m175s[0m 2s/step - accuracy: 0.1543 - auc: 0.6789 - loss: 0.2384 - precision: 0.2223 - recall: 0.0250 - val_accuracy: 0.2750 - val_auc: 0.7393 - val_loss: 0.2079 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 4/30
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.1910 - auc: 0.7056 - l



[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m208s[0m 3s/step - accuracy: 0.1914 - auc: 0.7058 - loss: 0.2286 - precision: 0.2568 - recall: 0.0257 - val_accuracy: 0.3406 - val_auc: 0.7544 - val_loss: 0.2053 - val_precision: 0.6818 - val_recall: 0.0758
Epoch 5/30
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.2598 - auc: 0.7478 - loss: 0.2175 - precision: 0.4206 - recall: 0.0773
Epoch 5: val_auc improved from 0.75445 to 0.77716, saving model to best_biobert_medical_model.h5




[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m217s[0m 3s/step - accuracy: 0.2600 - auc: 0.7478 - loss: 0.2175 - precision: 0.4211 - recall: 0.0774 - val_accuracy: 0.3625 - val_auc: 0.7772 - val_loss: 0.2011 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 6/30
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.2982 - auc: 0.7653 - loss: 0.2112 - precision: 0.4811 - recall: 0.0648
Epoch 6: val_auc improved from 0.77716 to 0.80519, saving model to best_biobert_medical_model.h5




[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m212s[0m 3s/step - accuracy: 0.2985 - auc: 0.7653 - loss: 0.2112 - precision: 0.4818 - recall: 0.0652 - val_accuracy: 0.4031 - val_auc: 0.8052 - val_loss: 0.1904 - val_precision: 0.5347 - val_recall: 0.1364
Epoch 7/30
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.3499 - auc: 0.7627 - loss: 0.2077 - precision: 0.4983 - recall: 0.1168
Epoch 7: val_auc did not improve from 0.80519
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m212s[0m 2s/step - accuracy: 0.3498 - auc: 0.7628 - loss: 0.2077 - precision: 0.4984 - recall: 0.1167 - val_accuracy: 0.3656 - val_auc: 0.7955 - val_loss: 0.1951 - val_precision: 0.6111 - val_recall: 0.0833
Epoch 8/30
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.3433 - auc: 0.7838 - loss: 0.2052 - pr



[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m263s[0m 3s/step - accuracy: 0.3336 - auc: 0.7886 - loss: 0.2019 - precision: 0.5275 - recall: 0.1178 - val_accuracy: 0.4125 - val_auc: 0.8294 - val_loss: 0.1809 - val_precision: 0.6463 - val_recall: 0.1338
Epoch 10/30
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.3649 - auc: 0.8044 - loss: 0.1955 - precision: 0.5302 - recall: 0.1322
Epoch 10: val_auc did not improve from 0.82944
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m163s[0m 2s/step - accuracy: 0.3650 - auc: 0.8044 - loss: 0.1955 - precision: 0.5304 - recall: 0.1322 - val_accuracy: 0.4062 - val_auc: 0.8099 - val_loss: 0.1889 - val_precision: 0.5565 - val_recall: 0.1616
Epoch 11/30
[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - accuracy: 0.3960 - auc: 0.8015 - loss: 0.1944 -



[1m80/80[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m234s[0m 3s/step - accuracy: 0.4277 - auc: 0.8184 - loss: 0.1883 - precision: 0.6033 - recall: 0.1785 - val_accuracy: 0.4344 - val_auc: 0.8388 - val_loss: 0.1825 - val_precision: 0.6375 - val_recall: 0.2576
Epoch 16: early stopping
Restoring model weights from the end of the best epoch: 9.
Evaluating model...


  self._warn_if_super_not_called()


[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 1s/step - accuracy: 0.4439 - auc: 0.8596 - loss: 0.1744 - precision: 0.6034 - recall: 0.1460

Test Results:
Test Loss: 0.1830
Test Accuracy: 0.4062
Test Precision: 0.5281
Test Recall: 0.1158
Test AUC: 0.8402
Training completed successfully!

Architecture Summary:
1. Visual Features: DenseNet201 + Multi-scale features → 512-dim vectors
2. BioBERT Disease Embeddings → 512-dim vectors
3. Combined features fed to Transformer encoders
4. Final classification predictions


In [None]:
# # import numpy as np
# # import pandas as pd
# # import tensorflow as tf
# # from tensorflow.keras.applications.densenet import DenseNet201
# # from tensorflow.keras.preprocessing.image import ImageDataGenerator
# # from tensorflow.keras.models import Model
# # from tensorflow.keras.optimizers import Adam
# # from tensorflow.keras.layers import (Dense, GlobalAveragePooling2D, Conv2D,
# #                                    Conv2DTranspose, Reshape, Add, Multiply,
# #                                    MultiHeadAttention, LayerNormalization, Dropout,
# #                                    Input, Concatenate)
# # from tensorflow.keras.metrics import Recall, Precision, AUC
# # from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
# # from zipfile import ZipFile
# # import os

# # # Install required packages for BioBERT
# # try:
# #     import transformers
# #     from transformers import AutoTokenizer, TFAutoModel
# #     print("Transformers library already installed")
# # except ImportError:
# #     print("Installing transformers library...")
# #     os.system("pip install transformers")
# #     import transformers
# #     from transformers import AutoTokenizer, TFAutoModel

# # # Mount Google Drive (if using Colab)
# # try:
# #     from google.colab import drive
# #     drive.mount('/content/drive')
# #     print("Google Drive mounted successfully")
# # except ImportError:
# #     print("Not running in Colab, skipping drive mount")

# # # BioBERT Disease Embedding Class
# # class BioBERTDiseaseEmbedder:
# #     """Class to generate disease embeddings using BioBERT"""

# #     def __init__(self, model_name='dmis-lab/biobert-base-cased-v1.1', embedding_dim=512):
# #         self.model_name = model_name
# #         self.embedding_dim = embedding_dim
# #         self.tokenizer = None
# #         self.model = None
# #         self.disease_embeddings = None

# #     def load_model(self):
# #         """Load BioBERT tokenizer and model"""
# #         print("Loading BioBERT model...")
# #         try:
# #             self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
# #             # Try loading with PyTorch conversion first
# #             self.model = TFAutoModel.from_pretrained(self.model_name, from_pt=True)
# #             print("BioBERT model loaded successfully (converted from PyTorch)")
# #         except Exception as e:
# #             print(f"Error loading BioBERT model: {e}")
# #             print("Falling back to alternative biomedical model...")
# #             # Fallback to a model with native TensorFlow support
# #             try:
# #                 self.model_name = 'bert-base-uncased'  # Fallback option
# #                 self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
# #                 self.model = TFAutoModel.from_pretrained(self.model_name)
# #                 print("Fallback model loaded successfully")
# #             except Exception as e2:
# #                 print(f"Fallback model also failed: {e2}")
# #                 raise e2

# #     def generate_disease_embeddings(self, disease_labels):
# #         """Generate embeddings for disease labels"""
# #         if self.tokenizer is None or self.model is None:
# #             self.load_model()

# #         embeddings = []

# #         print("Generating disease embeddings...")
# #         for disease in disease_labels:
# #             # Tokenize the disease name
# #             inputs = self.tokenizer(
# #                 disease,
# #                 return_tensors='tf',
# #                 padding=True,
# #                 truncation=True,
# #                 max_length=64
# #             )

# #             # Get BioBERT embeddings
# #             outputs = self.model(**inputs)

# #             # Use [CLS] token embedding (first token)
# #             cls_embedding = outputs.last_hidden_state[:, 0, :]  # Shape: (1, 768)

# #             embeddings.append(cls_embedding.numpy())

# #         # Stack all embeddings
# #         disease_embeddings = np.vstack(embeddings)  # Shape: (num_diseases, 768)

# #         # Project to desired embedding dimension if needed
# #         if disease_embeddings.shape[1] != self.embedding_dim:
# #             projection_matrix = np.random.normal(
# #                 0, 0.02, (disease_embeddings.shape[1], self.embedding_dim)
# #             )
# #             disease_embeddings = disease_embeddings @ projection_matrix

# #         self.disease_embeddings = disease_embeddings
# #         print(f"Disease embeddings generated: {disease_embeddings.shape}")

# #         return disease_embeddings

# # # Data extraction
# # def extract_data():
# #     """Extract dataset from zip file"""
# #     try:
# #         with ZipFile("/content/drive/MyDrive/mured.zip", 'r') as zip_file:
# #             zip_file.extractall()
# #             print("Data extraction completed")
# #     except FileNotFoundError:
# #         print("Warning: Zip file not found. Please ensure data is available.")

# # # Data loading and preprocessing
# # def load_and_prepare_data():
# #     """Load and prepare training and test data"""
# #     try:
# #         # Load data
# #         train_data = pd.read_csv('/content/drive/MyDrive/train_data_modified.csv')
# #         test_data = pd.read_csv('/content/drive/MyDrive/test_data_modified.csv')

# #         # Sample data for training (adjust as needed)
# #         train_data = train_data.sample(frac=1, random_state=42)[:1600]
# #         test_data = test_data[:320]

# #         print(f"Training data shape: {train_data.shape}")
# #         print(f"Test data shape: {test_data.shape}")

# #         return train_data, test_data
# #     except FileNotFoundError:
# #         print("Error: CSV files not found. Please check file paths.")
# #         return None, None

# # # Define disease labels
# # DISEASE_LABELS = ['DR', 'NORMAL', 'MH', 'ODC', 'TSLN', 'ARMD', 'DN', 'MYA',
# #                   'BRVO', 'ODP', 'CRVO', 'CNV', 'RS', 'ODE', 'LS', 'CSR',
# #                   'HTR', 'ASR', 'CRS', 'OTHER']

# # DISEASE_LABELS_FULL = ['DIABETIC RETINOPATHY', 'NORMAL', 'MEDIA HAZE',
# #                        'OPTIC DISC COLOBOMA', 'TESSELLATION',
# #                        'AGE RELATED MACULAR DEGENERATION', 'DRUSEN', 'MYOPIA',
# #                        'BRANCH RETINAL VEIN OCCLUSION', 'OPTIC DISC PALLOR',
# #                        'CENTRAL RETINAL VEIN OCCLUSION', 'CHOROIDAL NEOVASCULARIZATION',
# #                        'RETINITIS', 'OPTIC DISC EDEMA', 'LASER SCARS',
# #                        'CENTRAL SEROUS RETINOPATHY', 'HYPERTENSIVE RETINOPATHY',
# #                        'ARTIFICIAL SILICON RETINA', 'CHORIORETINITIS', 'OTHER']

# # # Data generators
# # def create_data_generators(train_data, test_data, batch_size=16, img_size=(320, 320)):
# #     """Create data generators for training, validation, and testing"""

# #     # Training data generator with augmentation
# #     train_datagen = ImageDataGenerator(
# #         rescale=1./255,
# #         validation_split=0.2,
# #         rotation_range=40,
# #         width_shift_range=0.2,
# #         height_shift_range=0.2,
# #         shear_range=0.2,
# #         zoom_range=0.2,
# #         horizontal_flip=True,
# #         fill_mode='nearest'
# #     )

# #     # Test data generator (no augmentation)
# #     test_datagen = ImageDataGenerator(rescale=1./255)

# #     train_generator = train_datagen.flow_from_dataframe(
# #         dataframe=train_data,
# #         directory="/content/images/images",
# #         x_col="ID_2",
# #         y_col=DISEASE_LABELS,
# #         class_mode='raw',
# #         batch_size=batch_size,
# #         target_size=img_size,
# #         subset='training'
# #     )

# #     val_generator = train_datagen.flow_from_dataframe(
# #         dataframe=train_data,
# #         directory="/content/images/images",
# #         x_col="ID_2",
# #         y_col=DISEASE_LABELS,
# #         class_mode='raw',
# #         batch_size=batch_size,
# #         target_size=img_size,
# #         subset='validation'
# #     )

# #     test_generator = test_datagen.flow_from_dataframe(
# #         dataframe=test_data,
# #         directory="/content/images/images",
# #         x_col="ID_2",
# #         y_col=DISEASE_LABELS,
# #         class_mode='raw',
# #         batch_size=batch_size,
# #         target_size=img_size,
# #         shuffle=False
# #     )

# #     return train_generator, val_generator, test_generator

# # # Custom layers
# # class FullyConnectedLayer(tf.keras.layers.Layer):
# #     """Custom fully connected layer for transformer-like architecture"""

# #     def __init__(self, embedding_dim, fully_connected_dim, dropout_rate=0.1, **kwargs):
# #         super(FullyConnectedLayer, self).__init__(**kwargs)
# #         self.embedding_dim = embedding_dim
# #         self.fully_connected_dim = fully_connected_dim
# #         self.dropout_rate = dropout_rate

# #         self.dense1 = Dense(fully_connected_dim, activation='relu')
# #         self.dense2 = Dense(embedding_dim)
# #         self.dropout = Dropout(dropout_rate)

# #     def call(self, x, training=False):
# #         x = self.dense1(x)
# #         x = self.dropout(x, training=training)
# #         return self.dense2(x)

# #     def get_config(self):
# #         config = super().get_config()
# #         config.update({
# #             "embedding_dim": self.embedding_dim,
# #             "fully_connected_dim": self.fully_connected_dim,
# #             "dropout_rate": self.dropout_rate
# #         })
# #         return config

# # class EncoderLayer(tf.keras.layers.Layer):
# #     """Transformer encoder layer with multi-head attention"""

# #     def __init__(self, embedding_dim, num_heads, fully_connected_dim,
# #                  dropout_rate=0.1, **kwargs):
# #         super(EncoderLayer, self).__init__(**kwargs)

# #         self.embedding_dim = embedding_dim
# #         self.num_heads = num_heads
# #         self.fully_connected_dim = fully_connected_dim
# #         self.dropout_rate = dropout_rate

# #         self.mha = MultiHeadAttention(
# #             num_heads=num_heads,
# #             key_dim=embedding_dim,
# #             dropout=dropout_rate
# #         )

# #         self.ffn = FullyConnectedLayer(embedding_dim, fully_connected_dim, dropout_rate)
# #         self.layernorm1 = LayerNormalization(epsilon=1e-6)
# #         self.layernorm2 = LayerNormalization(epsilon=1e-6)
# #         self.dropout1 = Dropout(dropout_rate)

# #     def call(self, inputs, training=False):
# #         # Multi-head attention
# #         attn_output = self.mha(inputs, inputs, inputs, training=training)
# #         attn_output = self.dropout1(attn_output, training=training)
# #         out1 = self.layernorm1(inputs + attn_output)

# #         # Feed forward network
# #         ffn_output = self.ffn(out1, training=training)
# #         encoder_output = self.layernorm2(out1 + ffn_output)

# #         return encoder_output

# #     def get_config(self):
# #         config = super().get_config()
# #         config.update({
# #             "embedding_dim": self.embedding_dim,
# #             "num_heads": self.num_heads,
# #             "fully_connected_dim": self.fully_connected_dim,
# #             "dropout_rate": self.dropout_rate
# #         })
# #         return config

# # class GlobalMeanPoolingLayer(tf.keras.layers.Layer):
# #     """Custom layer for global mean pooling along sequence dimension"""

# #     def __init__(self, **kwargs):
# #         super(GlobalMeanPoolingLayer, self).__init__(**kwargs)

# #     def call(self, inputs):
# #         return tf.reduce_mean(inputs, axis=1)

# #     def get_config(self):
# #         return super().get_config()

# # class DiseaseEmbeddingExpansionLayer(tf.keras.layers.Layer):
# #     """Layer to expand disease embeddings based on batch size"""

# #     def __init__(self, num_classes, **kwargs):
# #         super(DiseaseEmbeddingExpansionLayer, self).__init__(**kwargs)
# #         self.num_classes = num_classes

# #     def call(self, inputs):
# #         # inputs is a tuple of (batch_reference, disease_embeddings)
# #         batch_reference, disease_embeddings = inputs
# #         batch_size = tf.shape(batch_reference)[0]

# #         # Expand disease embeddings to match batch size
# #         expanded = tf.expand_dims(disease_embeddings, 0)
# #         tiled = tf.tile(expanded, [batch_size, 1, 1])
# #         return tiled

# #     def get_config(self):
# #         config = super().get_config()
# #         config.update({"num_classes": self.num_classes})
# #         return config
# #     """Layer to provide BioBERT disease embeddings to the model"""

# #     def __init__(self, disease_embeddings, embedding_dim=512, **kwargs):
# #         super(BioBERTDiseaseEmbeddingLayer, self).__init__(**kwargs)
# #         self.embedding_dim = embedding_dim
# #         self.num_classes = disease_embeddings.shape[0]

# #         # Store disease embeddings as a weight (trainable parameter)
# #         self.disease_embeddings_weight = None
# #         self.disease_embeddings_np = disease_embeddings.astype(np.float32)

# #     def build(self, input_shape):
# #         # Create the embeddings as a trainable weight
# #         self.disease_embeddings_weight = self.add_weight(
# #             name='disease_embeddings',
# #             shape=(self.num_classes, self.embedding_dim),
# #             initializer='zeros',
# #             trainable=False  # Keep as non-trainable since they're pre-computed
# #         )
# #         # Initialize with BioBERT embeddings
# #         self.disease_embeddings_weight.assign(self.disease_embeddings_np)
# #         super().build(input_shape)

# #     def call(self, inputs):
# #         # Get batch size from input tensor
# #         batch_size = tf.shape(inputs)[0]

# #         # Expand disease embeddings to match batch size
# #         # Shape: (num_classes, embedding_dim) -> (batch_size, num_classes, embedding_dim)
# #         expanded_embeddings = tf.expand_dims(self.disease_embeddings_weight, 0)
# #         tiled_embeddings = tf.tile(expanded_embeddings, [batch_size, 1, 1])

# #         return tiled_embeddings

# #     def get_config(self):
# #         config = super().get_config()
# #         config.update({
# #             "embedding_dim": self.embedding_dim,
# #             "num_classes": self.num_classes
# #         })
# #         return config

# # # Model building functions
# # def create_base_model(input_shape=(320, 320, 3)):
# #     """Create base DenseNet201 model"""
# #     base_model = DenseNet201(
# #         include_top=False,
# #         weights='imagenet',
# #         input_shape=input_shape
# #     )
# #     return base_model

# # def create_multi_scale_features(base_model):
# #     """Create multi-scale feature extraction"""

# #     # Get features from different layers
# #     high_level_features = base_model.output  # Shape: (batch, 10, 10, 1920)
# #     low_level_output = base_model.layers[-228].output  # Shape: (batch, 20, 20, 1792)

# #     # Create models for different feature levels
# #     low_level_model = Model(inputs=base_model.input, outputs=low_level_output)

# #     # Multi-Scale Feature Module (MSFM)
# #     f_h = Conv2D(512, (1, 1), activation='relu', name='high_level_conv')(high_level_features)
# #     f_l = Conv2D(512, (1, 1), activation='relu', name='low_level_conv')(low_level_output)

# #     # Upsample high-level features
# #     f_up = Conv2DTranspose(512, kernel_size=(4, 4), strides=(2, 2),
# #                           padding='same', activation='relu', name='upsample')(f_h)

# #     # Combine features
# #     f_combined = Add(name='feature_add')([f_up, f_l])
# #     f_refined = Conv2D(512, 3, padding='same', activation='relu', name='refined_conv')(f_combined)

# #     # Channel Attention Module (CAM)
# #     f_gap = GlobalAveragePooling2D()(f_refined)
# #     f_gap_reshaped = Reshape((1, 1, 512))(f_gap)

# #     f_attention = Conv2D(512, (1, 1), activation='relu')(f_gap_reshaped)
# #     f_attention = Conv2D(512, (1, 1), activation='sigmoid')(f_attention)

# #     f_attended = Multiply()([f_refined, f_attention])
# #     f_final = Add()([f_refined, f_attended])

# #     return f_final, f_h

# # def create_biobert_enhanced_model(input_shape=(320, 320, 3), num_classes=20, disease_embeddings=None):
# #     """Create the complete model with BioBERT disease embeddings"""

# #     # Input layer
# #     image_input = Input(shape=input_shape, name='image_input')

# #     # Base DenseNet201 for visual features
# #     base_model = DenseNet201(
# #         include_top=False,
# #         weights='imagenet',
# #         input_shape=input_shape
# #     )

# #     # Get visual features
# #     visual_features = base_model(image_input)

# #     # Multi-scale feature processing
# #     f_h = Conv2D(512, (1, 1), activation='relu', name='high_level_conv')(visual_features)

# #     # Try to get low-level features (with error handling)
# #     try:
# #         low_level_output = base_model.layers[-228].output
# #         low_level_model = Model(inputs=base_model.input, outputs=low_level_output)
# #         low_level_features = low_level_model(image_input)
# #         f_l = Conv2D(512, (1, 1), activation='relu', name='low_level_conv')(low_level_features)

# #         # Upsample high-level features
# #         f_up = Conv2DTranspose(512, kernel_size=(4, 4), strides=(2, 2),
# #                               padding='same', activation='relu', name='upsample')(f_h)

# #         # Combine features
# #         f_combined = Add(name='feature_add')([f_up, f_l])
# #         f_refined = Conv2D(512, 3, padding='same', activation='relu', name='refined_conv')(f_combined)
# #     except:
# #         # Fallback: just use high-level features
# #         print("Warning: Could not access low-level features, using high-level only")
# #         f_refined = Conv2D(512, 3, padding='same', activation='relu', name='refined_conv')(f_h)

# #     # Global pooling to get visual feature vector
# #     visual_pooled = GlobalAveragePooling2D(name='visual_gap')(f_refined)

# #     # Reshape visual features for transformer: (batch_size, 1, 512)
# #     visual_reshaped = Reshape((1, 512), name='visual_reshape')(visual_pooled)

# #     # BioBERT Disease Embeddings
# #     if disease_embeddings is not None:
# #         # Create a constant layer for disease embeddings
# #         disease_embeddings_const = tf.constant(disease_embeddings.astype(np.float32))

# #         # Create a layer to expand disease embeddings based on batch size
# #         class DiseaseEmbeddingLayer(tf.keras.layers.Layer):
# #             def __init__(self, embeddings, **kwargs):
# #                 super().__init__(**kwargs)
# #                 self.embeddings = embeddings

# #             def call(self, inputs):
# #                 batch_size = tf.shape(inputs)[0]
# #                 expanded = tf.expand_dims(self.embeddings, 0)
# #                 return tf.tile(expanded, [batch_size, 1, 1])

# #         disease_embedding_layer = DiseaseEmbeddingLayer(disease_embeddings_const, name='disease_embeddings')
# #         disease_features = disease_embedding_layer(image_input)
# #     else:
# #         # Fallback to learnable embeddings
# #         embedding_layer = tf.keras.layers.Embedding(num_classes, 512, name='learnable_disease_embeddings')
# #         indices = tf.range(num_classes)

# #         # Create a layer to handle the embedding expansion
# #         class LearnableEmbeddingLayer(tf.keras.layers.Layer):
# #             def __init__(self, embedding_layer, indices, **kwargs):
# #                 super().__init__(**kwargs)
# #                 self.embedding_layer = embedding_layer
# #                 self.indices = indices

# #             def call(self, inputs):
# #                 batch_size = tf.shape(inputs)[0]
# #                 embeddings = self.embedding_layer(self.indices)
# #                 expanded = tf.expand_dims(embeddings, 0)
# #                 return tf.tile(expanded, [batch_size, 1, 1])

# #         learnable_layer = LearnableEmbeddingLayer(embedding_layer, indices, name='learnable_expansion')
# #         disease_features = learnable_layer(image_input)

# #     # Concatenate visual and disease features
# #     combined_features = Concatenate(axis=1, name='feature_concat')([visual_reshaped, disease_features])

# #     # Transformer encoder layers
# #     transformer1 = EncoderLayer(
# #         embedding_dim=512,
# #         num_heads=8,
# #         fully_connected_dim=2048,
# #         dropout_rate=0.1,
# #         name='transformer_encoder_1'
# #     )
# #     encoded_features1 = transformer1(combined_features)

# #     # Second transformer layer
# #     transformer2 = EncoderLayer(
# #         embedding_dim=512,
# #         num_heads=8,
# #         fully_connected_dim=2048,
# #         dropout_rate=0.1,
# #         name='transformer_encoder_2'
# #     )
# #     encoded_features2 = transformer2(encoded_features1)

# #     # Global pooling of all tokens using custom layer
# #     global_pool = GlobalMeanPoolingLayer(name='global_mean_pool')
# #     final_features = global_pool(encoded_features2)

# #     # Classification head
# #     x = Dense(1024, activation='relu', name='classifier_dense1')(final_features)
# #     x = Dropout(0.5, name='classifier_dropout1')(x)
# #     x = Dense(512, activation='relu', name='classifier_dense2')(x)
# #     x = Dropout(0.3, name='classifier_dropout2')(x)
# #     x = Dense(256, activation='relu', name='classifier_dense3')(x)
# #     x = Dropout(0.2, name='classifier_dropout3')(x)

# #     # Final predictions
# #     predictions = Dense(num_classes, activation='sigmoid', name='predictions')(x)

# #     # Create final model
# #     model = Model(inputs=image_input, outputs=predictions, name='BioBERT_Medical_Classifier')

# #     return model

# # def create_complete_model(input_shape=(320, 320, 3), num_classes=20, disease_embeddings=None):
# #     """Create the complete model with BioBERT integration"""
# #     return create_biobert_enhanced_model(input_shape, num_classes, disease_embeddings)

# # def compile_model(model, learning_rate=0.001):
# #     """Compile the model with appropriate optimizer and metrics"""

# #     model.compile(
# #         optimizer=Adam(learning_rate=learning_rate),
# #         loss='binary_crossentropy',
# #         metrics=[
# #             'accuracy',
# #             Precision(name='precision'),
# #             Recall(name='recall'),
# #             AUC(name='auc')
# #         ]
# #     )

# #     return model

# # def create_callbacks(model_save_path='best_biobert_model.h5'):
# #     """Create training callbacks"""

# #     callbacks = [
# #         EarlyStopping(
# #             monitor='val_loss',
# #             patience=7,
# #             restore_best_weights=True,
# #             verbose=1
# #         ),
# #         ModelCheckpoint(
# #             model_save_path,
# #             monitor='val_auc',
# #             save_best_only=True,
# #             mode='max',
# #             verbose=1
# #         )
# #     ]

# #     return callbacks

# # def train_model(model, train_gen, val_gen, epochs=50, callbacks=None):
# #     """Train the model"""

# #     history = model.fit(
# #         train_gen,
# #         steps_per_epoch=len(train_gen),
# #         epochs=epochs,
# #         validation_data=val_gen,
# #         validation_steps=len(val_gen),
# #         callbacks=callbacks,
# #         verbose=1
# #     )

# #     return history

# # def evaluate_model(model, test_gen):
# #     """Evaluate the model on test data"""

# #     test_loss, test_accuracy, test_precision, test_recall, test_auc = model.evaluate(
# #         test_gen,
# #         steps=len(test_gen),
# #         verbose=1
# #     )

# #     print(f"\nTest Results:")
# #     print(f"Test Loss: {test_loss:.4f}")
# #     print(f"Test Accuracy: {test_accuracy:.4f}")
# #     print(f"Test Precision: {test_precision:.4f}")
# #     print(f"Test Recall: {test_recall:.4f}")
# #     print(f"Test AUC: {test_auc:.4f}")

# #     return {
# #         'test_loss': test_loss,
# #         'test_accuracy': test_accuracy,
# #         'test_precision': test_precision,
# #         'test_recall': test_recall,
# #         'test_auc': test_auc
# #     }

# # # Main execution function
# # def main():
# #     """Main function to run the complete pipeline with BioBERT"""

# #     print("Starting Medical Image Classification Pipeline with BioBERT...")

# #     # Extract data
# #     extract_data()

# #     # Load and prepare data
# #     train_data, test_data = load_and_prepare_data()
# #     if train_data is None or test_data is None:
# #         print("Failed to load data. Exiting.")
# #         return

# #     # Generate BioBERT disease embeddings
# #     print("Generating BioBERT disease embeddings...")
# #     biobert_embedder = BioBERTDiseaseEmbedder(embedding_dim=512)
# #     disease_embeddings = biobert_embedder.generate_disease_embeddings(DISEASE_LABELS_FULL)

# #     # Create data generators
# #     train_gen, val_gen, test_gen = create_data_generators(train_data, test_data)

# #     print("Data generators created successfully")
# #     print(f"Training samples: {train_gen.n}")
# #     print(f"Validation samples: {val_gen.n}")
# #     print(f"Test samples: {test_gen.n}")

# #     # Create and compile model with BioBERT embeddings
# #     print("Creating BioBERT-enhanced model...")
# #     model = create_complete_model(
# #         input_shape=(320, 320, 3),
# #         num_classes=20,
# #         disease_embeddings=disease_embeddings
# #     )
# #     model = compile_model(model, learning_rate=0.0001)  # Lower learning rate for stability

# #     print("Model created and compiled successfully")
# #     print(f"Model parameters: {model.count_params():,}")

# #     # Print model summary
# #     model.summary()

# #     # Create callbacks
# #     callbacks = create_callbacks('best_biobert_medical_model.h5')

# #     # Train model
# #     print("Starting training...")
# #     history = train_model(
# #         model,
# #         train_gen,
# #         val_gen,
# #         epochs=30,
# #         callbacks=callbacks
# #     )

# #     # Evaluate model
# #     print("Evaluating model...")
# #     test_results = evaluate_model(model, test_gen)

# #     print("Training completed successfully!")
# #     print("\nArchitecture Summary:")
# #     print("1. Visual Features: DenseNet201 + Multi-scale features → 512-dim vectors")
# #     print("2. BioBERT Disease Embeddings → 512-dim vectors")
# #     print("3. Combined features fed to Transformer encoders")
# #     print("4. Final classification predictions")

# #     return model, history, test_results, disease_embeddings

# # # Run the pipeline
# # if __name__ == "__main__":
# #     model, history, results, embeddings = main()

# import numpy as np
# import pandas as pd
# import tensorflow as tf
# from tensorflow.keras.preprocessing import image
# from tensorflow.keras.models import load_model
# import matplotlib.pyplot as plt
# import matplotlib.patches as patches
# from google.colab import files
# from PIL import Image, ImageDraw, ImageFont
# import io
# import cv2

# # Disease labels (same as in training)
# DISEASE_LABELS = ['DR', 'NORMAL', 'MH', 'ODC', 'TSLN', 'ARMD', 'DN', 'MYA',
#                   'BRVO', 'ODP', 'CRVO', 'CNV', 'RS', 'ODE', 'LS', 'CSR',
#                   'HTR', 'ASR', 'CRS', 'OTHER']

# DISEASE_LABELS_FULL = ['DIABETIC RETINOPATHY', 'NORMAL', 'MEDIA HAZE',
#                        'OPTIC DISC COLOBOMA', 'TESSELLATION',
#                        'AGE RELATED MACULAR DEGENERATION', 'DRUSEN', 'MYOPIA',
#                        'BRANCH RETINAL VEIN OCCLUSION', 'OPTIC DISC PALLOR',
#                        'CENTRAL RETINAL VEIN OCCLUSION', 'CHOROIDAL NEOVASCULARIZATION',
#                        'RETINITIS', 'OPTIC DISC EDEMA', 'LASER SCARS',
#                        'CENTRAL SEROUS RETINOPATHY', 'HYPERTENSIVE RETINOPATHY',
#                        'ARTIFICIAL SILICON RETINA', 'CHORIORETINITIS', 'OTHER']

# # Define color mapping for different disease severity levels
# def get_color_for_probability(prob):
#     """Return color based on probability threshold"""
#     if prob >= 0.8:
#         return '#FF4444'  # High probability - Red
#     elif prob >= 0.6:
#         return '#FF8800'  # Medium-high probability - Orange
#     elif prob >= 0.4:
#         return '#FFCC00'  # Medium probability - Yellow
#     elif prob >= 0.2:
#         return '#88CC00'  # Low-medium probability - Light green
#     else:
#         return '#44AA44'  # Low probability - Green

# def load_trained_model():
#     """Load the trained BioBERT model"""
#     try:
#         print("Loading trained BioBERT model...")

#         # Custom objects for loading the model
#         custom_objects = {
#             'FullyConnectedLayer': FullyConnectedLayer,
#             'EncoderLayer': EncoderLayer,
#             'GlobalMeanPoolingLayer': GlobalMeanPoolingLayer,
#             'DiseaseEmbeddingExpansionLayer': DiseaseEmbeddingExpansionLayer
#         }

#         model = load_model('/content/best_biobert_medical_model.h5',
#                           custom_objects=custom_objects,
#                           compile=False)
#         print("✅ Model loaded successfully!")
#         return model
#     except Exception as e:
#         print(f"❌ Error loading model: {e}")
#         print("Please ensure the model file exists at: /content/best_biobert_medical_model.h5")
#         return None

# def preprocess_image(img_path, target_size=(320, 320)):
#     """Preprocess uploaded image for prediction"""
#     try:
#         # Load image
#         img = image.load_img(img_path, target_size=target_size)

#         # Convert to array and normalize
#         img_array = image.img_to_array(img)
#         img_array = np.expand_dims(img_array, axis=0)
#         img_array = img_array / 255.0  # Normalize to [0,1]

#         return img_array, img
#     except Exception as e:
#         print(f"❌ Error preprocessing image: {e}")
#         return None, None

# def create_prediction_visualization(original_img, predictions, probabilities, threshold=0.3):
#     """Create a comprehensive visualization of predictions"""

#     # Create figure with subplots
#     fig, axes = plt.subplots(2, 2, figsize=(16, 12))
#     fig.suptitle('BioBERT Medical Image Classification Results', fontsize=16, fontweight='bold')

#     # 1. Original Image
#     axes[0, 0].imshow(original_img)
#     axes[0, 0].set_title('Original Retinal Image', fontsize=14, fontweight='bold')
#     axes[0, 0].axis('off')

#     # 2. Predictions above threshold
#     significant_preds = [(label, full_name, prob) for label, full_name, prob in
#                         zip(DISEASE_LABELS, DISEASE_LABELS_FULL, probabilities)
#                         if prob >= threshold]

#     if significant_preds:
#         # Sort by probability
#         significant_preds.sort(key=lambda x: x[2], reverse=True)

#         y_pos = np.arange(len(significant_preds))
#         probs = [pred[2] for pred in significant_preds]
#         labels = [f"{pred[0]}\n({pred[1][:25]}...)" if len(pred[1]) > 25
#                  else f"{pred[0]}\n({pred[1]})" for pred in significant_preds]
#         colors = [get_color_for_probability(prob) for prob in probs]

#         bars = axes[0, 1].barh(y_pos, probs, color=colors, alpha=0.8)
#         axes[0, 1].set_yticks(y_pos)
#         axes[0, 1].set_yticklabels(labels, fontsize=10)
#         axes[0, 1].set_xlabel('Probability', fontsize=12)
#         axes[0, 1].set_title(f'Detected Conditions (≥{threshold})', fontsize=14, fontweight='bold')
#         axes[0, 1].set_xlim(0, 1)

#         # Add probability values on bars
#         for bar, prob in zip(bars, probs):
#             axes[0, 1].text(bar.get_width() + 0.01, bar.get_y() + bar.get_height()/2,
#                            f'{prob:.3f}', va='center', fontsize=10, fontweight='bold')
#     else:
#         axes[0, 1].text(0.5, 0.5, 'No significant conditions detected\n(all probabilities < threshold)',
#                        ha='center', va='center', fontsize=12, transform=axes[0, 1].transAxes)
#         axes[0, 1].set_title(f'Detected Conditions (≥{threshold})', fontsize=14, fontweight='bold')

#     # 3. All predictions heatmap
#     prob_matrix = probabilities.reshape(4, 5)  # 4x5 grid for 20 diseases
#     label_matrix = np.array(DISEASE_LABELS).reshape(4, 5)

#     im = axes[1, 0].imshow(prob_matrix, cmap='RdYlGn_r', aspect='auto', vmin=0, vmax=1)
#     axes[1, 0].set_title('All Conditions Probability Heatmap', fontsize=14, fontweight='bold')

#     # Add labels to heatmap
#     for i in range(4):
#         for j in range(5):
#             text = axes[1, 0].text(j, i, f'{label_matrix[i, j]}\n{prob_matrix[i, j]:.3f}',
#                                   ha="center", va="center", fontsize=9, fontweight='bold')

#     axes[1, 0].set_xticks([])
#     axes[1, 0].set_yticks([])

#     # Add colorbar
#     plt.colorbar(im, ax=axes[1, 0], fraction=0.046, pad=0.04)

#     # 4. Risk Assessment Summary
#     axes[1, 1].axis('off')

#     # Calculate risk levels
#     high_risk = sum(1 for p in probabilities if p >= 0.7)
#     medium_risk = sum(1 for p in probabilities if 0.4 <= p < 0.7)
#     low_risk = sum(1 for p in probabilities if 0.2 <= p < 0.4)

#     # Find top condition
#     max_idx = np.argmax(probabilities)
#     top_condition = DISEASE_LABELS_FULL[max_idx]
#     top_prob = probabilities[max_idx]

#     summary_text = f"""
# CLINICAL ASSESSMENT SUMMARY

# 🔴 High Risk Conditions (≥70%): {high_risk}
# 🟡 Medium Risk Conditions (40-69%): {medium_risk}
# 🟢 Low Risk Conditions (20-39%): {low_risk}

# 📊 TOP FINDING:
# {top_condition}
# Confidence: {top_prob:.1%}

# ⚠️  CLINICAL NOTES:
# • This is an AI-generated assessment
# • Requires professional medical evaluation
# • Not a substitute for clinical diagnosis
# • Consider patient history and symptoms

# 📋 RECOMMENDATION:
# {'Immediate medical attention recommended' if top_prob > 0.8
#  else 'Medical consultation advised' if top_prob > 0.5
#  else 'Routine follow-up suggested'}
#     """

#     axes[1, 1].text(0.05, 0.95, summary_text, transform=axes[1, 1].transAxes,
#                    fontsize=11, verticalalignment='top', fontfamily='monospace',
#                    bbox=dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.8))

#     plt.tight_layout()
#     plt.show()

# def generate_detailed_report(probabilities, threshold=0.2):
#     """Generate a detailed medical report"""
#     print("\n" + "="*80)
#     print("🏥 DETAILED MEDICAL IMAGE ANALYSIS REPORT")
#     print("="*80)

#     # Sort conditions by probability
#     condition_data = list(zip(DISEASE_LABELS, DISEASE_LABELS_FULL, probabilities))
#     condition_data.sort(key=lambda x: x[2], reverse=True)

#     print(f"\n📊 FINDINGS SUMMARY:")
#     print(f"{'Rank':<4} {'Code':<6} {'Condition':<35} {'Confidence':<12} {'Status'}")
#     print("-" * 80)

#     for rank, (code, full_name, prob) in enumerate(condition_data, 1):
#         if prob >= threshold:
#             status = "🔴 HIGH" if prob >= 0.7 else "🟡 MEDIUM" if prob >= 0.4 else "🟢 LOW"
#             print(f"{rank:<4} {code:<6} {full_name:<35} {prob:.1%}{'':>8} {status}")

#     # Clinical interpretation
#     print(f"\n🩺 CLINICAL INTERPRETATION:")
#     top_condition = condition_data[0]

#     if top_condition[2] >= 0.8:
#         interpretation = f"Strong indication of {top_condition[1]} detected."
#     elif top_condition[2] >= 0.6:
#         interpretation = f"Moderate indication of {top_condition[1]} observed."
#     elif top_condition[2] >= 0.4:
#         interpretation = f"Mild indication of {top_condition[1]} present."
#     else:
#         interpretation = "No significant pathological findings detected."

#     print(f"• {interpretation}")

#     # Recommendations
#     print(f"\n💡 RECOMMENDATIONS:")
#     if top_condition[2] >= 0.7:
#         print("• Immediate ophthalmological consultation recommended")
#         print("• Consider advanced imaging (OCT, fluorescein angiography)")
#         print("• Monitor for progression and complications")
#     elif top_condition[2] >= 0.4:
#         print("• Schedule routine ophthalmological follow-up")
#         print("• Monitor symptoms and visual changes")
#         print("• Consider lifestyle modifications if applicable")
#     else:
#         print("• Continue routine eye care and regular check-ups")
#         print("• Maintain healthy lifestyle habits")

#     print(f"\n⚠️  IMPORTANT DISCLAIMERS:")
#     print("• This AI analysis is for research/educational purposes only")
#     print("• Not intended for clinical diagnosis or treatment decisions")
#     print("• Professional medical evaluation is always required")
#     print("• Consider patient history, symptoms, and clinical context")

#     print("="*80)

# def predict_single_image():
#     """Main function for single image prediction"""

#     # Load the trained model
#     model = load_trained_model()
#     if model is None:
#         return

#     print("\n🔬 BioBERT Medical Image Classifier")
#     print("=" * 50)
#     print("Upload a retinal fundus image for analysis...")

#     # Upload image
#     uploaded = files.upload()

#     if not uploaded:
#         print("❌ No image uploaded.")
#         return

#     # Process each uploaded image
#     for filename, data in uploaded.items():
#         print(f"\n📸 Processing: {filename}")
#         print("-" * 30)

#         # Save uploaded file temporarily
#         with open(filename, 'wb') as f:
#             f.write(data)

#         # Preprocess image
#         img_array, original_img = preprocess_image(filename)

#         if img_array is None:
#             continue

#         # Make prediction
#         print("🤖 Running BioBERT analysis...")
#         try:
#             predictions = model.predict(img_array, verbose=0)
#             probabilities = predictions[0]  # Get first (and only) sample

#             print("✅ Analysis complete!")

#             # Create visualization
#             create_prediction_visualization(original_img, predictions, probabilities)

#             # Generate detailed report
#             generate_detailed_report(probabilities)

#         except Exception as e:
#             print(f"❌ Error during prediction: {e}")
#             print("Please check that the model architecture matches the saved model.")

# def predict_batch_images():
#     """Function for batch image prediction"""

#     # Load the trained model
#     model = load_trained_model()
#     if model is None:
#         return

#     print("\n🔬 BioBERT Medical Image Classifier - Batch Mode")
#     print("=" * 60)
#     print("Upload multiple retinal fundus images for batch analysis...")

#     # Upload images
#     uploaded = files.upload()

#     if not uploaded:
#         print("❌ No images uploaded.")
#         return

#     results_summary = []

#     # Process each uploaded image
#     for idx, (filename, data) in enumerate(uploaded.items(), 1):
#         print(f"\n📸 Processing {idx}/{len(uploaded)}: {filename}")

#         # Save uploaded file temporarily
#         with open(filename, 'wb') as f:
#             f.write(data)

#         # Preprocess image
#         img_array, original_img = preprocess_image(filename)

#         if img_array is None:
#             continue

#         # Make prediction
#         try:
#             predictions = model.predict(img_array, verbose=0)
#             probabilities = predictions[0]

#             # Store results
#             top_idx = np.argmax(probabilities)
#             top_condition = DISEASE_LABELS_FULL[top_idx]
#             top_prob = probabilities[top_idx]

#             results_summary.append({
#                 'filename': filename,
#                 'top_condition': top_condition,
#                 'confidence': top_prob,
#                 'probabilities': probabilities
#             })

#             print(f"✅ Top finding: {top_condition} ({top_prob:.1%})")

#         except Exception as e:
#             print(f"❌ Error processing {filename}: {e}")

#     # Display batch summary
#     if results_summary:
#         print("\n" + "="*80)
#         print("📊 BATCH ANALYSIS SUMMARY")
#         print("="*80)

#         print(f"{'Image':<25} {'Top Condition':<35} {'Confidence'}")
#         print("-" * 80)

#         for result in results_summary:
#             print(f"{result['filename']:<25} {result['top_condition']:<35} {result['confidence']:.1%}")

#         # Create batch visualization
#         fig, axes = plt.subplots(2, min(3, len(results_summary)),
#                                 figsize=(5*min(3, len(results_summary)), 10))
#         if len(results_summary) == 1:
#             axes = axes.reshape(-1, 1)

#         for i, result in enumerate(results_summary[:3]):  # Show first 3
#             if len(results_summary) > 1:
#                 col = i
#             else:
#                 col = 0

#             # Load and display image
#             img = image.load_img(result['filename'], target_size=(320, 320))
#             axes[0, col].imshow(img)
#             axes[0, col].set_title(f"{result['filename'][:15]}...", fontsize=10)
#             axes[0, col].axis('off')

#             # Show top predictions
#             top_5_idx = np.argsort(result['probabilities'])[-5:][::-1]
#             top_5_probs = result['probabilities'][top_5_idx]
#             top_5_labels = [DISEASE_LABELS[idx] for idx in top_5_idx]

#             axes[1, col].barh(range(5), top_5_probs,
#                              color=[get_color_for_probability(p) for p in top_5_probs])
#             axes[1, col].set_yticks(range(5))
#             axes[1, col].set_yticklabels(top_5_labels)
#             axes[1, col].set_xlabel('Probability')
#             axes[1, col].set_title('Top 5 Predictions', fontsize=10)

#         plt.tight_layout()
#         plt.show()

# # Main interface
# def medical_image_classifier():
#     """Main interface for the medical image classifier"""

#     print("🏥 BioBERT-Enhanced Medical Image Classification System")
#     print("=" * 60)
#     print("Choose analysis mode:")
#     print("1. Single Image Analysis (detailed)")
#     print("2. Batch Image Analysis (multiple images)")

#     try:
#         choice = input("\nEnter your choice (1 or 2): ").strip()

#         if choice == "1":
#             predict_single_image()
#         elif choice == "2":
#             predict_batch_images()
#         else:
#             print("❌ Invalid choice. Please enter 1 or 2.")

#     except KeyboardInterrupt:
#         print("\n⏹️ Analysis cancelled by user.")
#     except Exception as e:
#         print(f"❌ Unexpected error: {e}")

# # Run the classifier
# print("🚀 Ready to analyze retinal images!")
# print("Run medical_image_classifier() to start the analysis")

# # Uncomment the line below to run immediately
# medical_image_classifier()