In [None]:
# Colab-only step: mount Google Drive at /content/drive so that files stored
# in Drive can be read through the local filesystem by the cells that follow.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import os
import zipfile

# Location of the dataset archive on the mounted Drive, and the local
# directory the archive is unpacked into.
zip_path = "/content/drive/My Drive/brain_tumor_dataset/binary_class.zip"
extract_path = "/content/binary_class"

# Unpack every member of the archive below extract_path; the context manager
# closes the archive once extraction has finished.
with zipfile.ZipFile(zip_path, mode='r') as archive:
    archive.extractall(path=extract_path)


In [None]:
# Show the top-level entries of the extraction directory as a quick check
# that the archive was unpacked.
os.listdir(extract_path)


['Testing', 'Training']

In [None]:
import os
os.path.exists("/content/binary_class")  # ➡️ should be True


True

In [None]:
# Write the notebook's third-party dependencies to requirements.txt.
# NOTE: `pathlib` is deliberately not listed. It is part of the Python 3
# standard library; the PyPI package of the same name is an obsolete backport
# that should not be installed on top of it.
with open("requirements.txt", "w") as f:
    f.write("""\
numpy
pandas
matplotlib
seaborn
scikit-learn
opencv-python
pillow

torch
torchvision
timm

catboost
optuna

shap

reportlab

streamlit
gradio

tqdm
""")


In [None]:
!pip install -r requirements.txt




In [None]:
from reportlab.lib.colors import Color

# Migration note: code that previously used
# RGBColor(255, 0, 0)

# is expressed with reportlab's Color and fractional channel values instead:
red_color = Color(1.0, 0.0, 0.0)


In [None]:
import argparse
import sys

# Drop the arguments injected by Colab/Jupyter: argparse reads sys.argv, so
# leave only a placeholder script name and let every option use its default.
sys.argv = ['script.py']

parser = argparse.ArgumentParser()
parser.add_argument(
    '--mode',
    choices=['train', 'predict', 'demo', 'streamlit', 'gradio'],
    required=False,
    default='train',
)
# The remaining options are all optional strings that default to None.
for option_name in ('--training_data', '--image', '--output_dir'):
    parser.add_argument(option_name, type=str, default=None)
args = parser.parse_args()


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.decomposition import PCA
from sklearn.metrics import (accuracy_score, classification_report, confusion_matrix,
                            precision_score, recall_score, f1_score)
import cv2
import os
import warnings
warnings.filterwarnings('ignore')

# Deep Learning imports
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import timm

# CatBoost and Optuna imports
from catboost import CatBoostClassifier
import optuna
from optuna.samplers import TPESampler

# Explainable AI imports
import shap

# Report generation imports
from reportlab.lib.pagesizes import letter
from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image as ReportImage, Table, TableStyle
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import inch
from reportlab.lib.colors import Color
import matplotlib.pyplot as plt
from io import BytesIO
import base64
from datetime import datetime

# Set random seeds for reproducibility: seed NumPy's global RNG and PyTorch's
# default generator with the same fixed value.
np.random.seed(42)
torch.manual_seed(42)

class MRIPreprocessor:
    """CLAHE-based preprocessing that turns MRI images into fixed-size arrays.

    Every image is contrast-enhanced with CLAHE, resized to ``target_size``
    and scaled to float32 values in the [0, 1] range.
    """

    def __init__(self, target_size=(224, 224)):
        """Store the output size and create the CLAHE operator.

        Args:
            target_size: Size handed to ``cv2.resize`` for every image.
        """
        self.target_size = target_size
        self.clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))

    def apply_clahe(self, image):
        """Apply Contrast Limited Adaptive Histogram Equalization.

        Args:
            image: Array with three dimensions (treated as RGB) or a
                single-channel 2-D array.

        Returns:
            The contrast-enhanced image with the same layout as the input.
        """
        if len(image.shape) == 3:
            # Equalize only the lightness channel in LAB space so that the
            # colour channels are left untouched.
            lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
            lab[:, :, 0] = self.clahe.apply(lab[:, :, 0])
            return cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
        return self.clahe.apply(image)

    def preprocess_image(self, image_path):
        """Run the complete preprocessing pipeline on one image.

        Args:
            image_path: A filesystem path (``str`` or ``os.PathLike``), a PIL
                image (anything exposing ``convert``), or an array that is
                already in RGB channel order.

        Returns:
            float32 array resized to ``target_size`` with values in [0, 1].

        Raises:
            ValueError: If the image cannot be loaded.
        """
        if isinstance(image_path, (str, os.PathLike)):
            image = cv2.imread(os.fspath(image_path))
            if image is None:
                raise ValueError(f"Could not load image: {image_path}")
            # cv2.imread yields BGR, while the rest of the pipeline (CLAHE in
            # RGB->LAB space and the downstream RGB normalisation) expects RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        elif hasattr(image_path, 'convert'):
            # PIL image: force three RGB channels.
            image = np.array(image_path.convert('RGB'))
        else:
            # Assume the caller already passed an image array.
            image = image_path

        if image is None:
            raise ValueError(f"Could not load image: {image_path}")

        # RGB arrays whose values lie in [0, 1] are rescaled to 8-bit before CLAHE.
        if len(image.shape) == 3 and image.shape[2] == 3:
            if np.max(image) <= 1.0:
                image = (image * 255).astype(np.uint8)

        # Apply CLAHE
        image = self.apply_clahe(image)

        # Resize to target size and scale to [0, 1]
        image = cv2.resize(image, self.target_size)
        return image.astype(np.float32) / 255.0

    def batch_preprocess(self, image_paths, labels=None):
        """Preprocess several images, skipping the ones that fail.

        Args:
            image_paths: Sequence of inputs accepted by ``preprocess_image``.
            labels: Optional sequence aligned with ``image_paths``.

        Returns:
            Tuple ``(images, labels)``: the stacked preprocessed images and
            the labels of the images that were processed successfully
            (``None`` when no labels were given).
        """
        processed_images = []
        valid_labels = []
        total = len(image_paths)

        print(f"Processing {total} images...")

        for i, path in enumerate(image_paths):
            try:
                img = self.preprocess_image(path)
                processed_images.append(img)
                if labels is not None:
                    valid_labels.append(labels[i])

                # Progress indicator
                if (i + 1) % 100 == 0 or (i + 1) == total:
                    print(f"Processed {i + 1}/{total} images ({(i + 1)/total*100:.1f}%)")

            except Exception as e:
                # Best effort: report the failing image and keep going so one
                # unreadable file does not abort the whole batch.
                print(f"Error processing {path}: {e}")
                continue

        return np.array(processed_images), (np.array(valid_labels) if labels is not None else None)

class ViTFeatureExtractor:
    """Vision Transformer (ViT-B16) feature extraction.

    A pre-trained ``vit_base_patch16_224`` backbone (classification head
    removed via ``num_classes=0``) produces the features, and a linear layer
    maps them to ``output_dim`` values.

    NOTE: the projection layer is created here with its default
    initialisation and is only ever used in eval mode under ``no_grad`` in
    this class, i.e. it is not trained by this code.
    """

    def __init__(self, device='cuda', output_dim=64):
        """Load the backbone, the projection layer and the input transform.

        Args:
            device: Requested torch device; ``'cpu'`` is used instead whenever
                CUDA is not available.
            output_dim: Number of features returned per image.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        # Load pre-trained ViT-B16
        self.model = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
        self.model.to(self.device)
        self.model.eval()

        # Add a projection layer to get desired output dimension
        self.projection = nn.Linear(self.model.num_features, output_dim).to(self.device)
        self.projection.eval()

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def extract_features(self, images, batch_size=16):
        """Extract projected ViT features for a collection of images.

        Args:
            images: Sequence (list or array) of image arrays. Images whose
                maximum is <= 1.0 are rescaled to 8-bit before the transform.
            batch_size: Number of images sent through the model at once.

        Returns:
            2-D NumPy array with one row per image and as many columns as the
            projection layer has outputs. An empty input yields an array of
            shape ``(0, output_dim)`` instead of failing in ``np.vstack``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        n_images = len(images)
        if n_images == 0:
            # np.vstack([]) raises, so return a correctly shaped empty result.
            return np.empty((0, self.projection.out_features), dtype=np.float32)

        features = []
        total_batches = (n_images + batch_size - 1) // batch_size

        print(f"Extracting ViT features in {total_batches} batches...")

        with torch.no_grad():
            for batch_idx in range(0, n_images, batch_size):
                batch_images = images[batch_idx:batch_idx + batch_size]
                batch_tensors = []

                for image in batch_images:
                    # The transform starts with ToPILImage, so [0, 1] images
                    # are converted back to 8-bit first.
                    if image.max() <= 1.0:
                        image = (image * 255).astype(np.uint8)
                    batch_tensors.append(self.transform(image))

                batch_tensor = torch.stack(batch_tensors).to(self.device)

                # Extract features from ViT
                vit_features = self.model(batch_tensor)

                # Project to desired dimension
                projected_features = self.projection(vit_features)

                features.append(projected_features.cpu().numpy())

                # Progress indicator
                current_batch = batch_idx // batch_size + 1
                print(f"Processed batch {current_batch}/{total_batches}")

        return np.vstack(features)

class OptunaCatBoostClassifier:
    """CatBoost classifier with Optuna hyperparameter optimization."""

    def __init__(self, n_trials=50, random_seed=42):
        """Configure the search.

        Args:
            n_trials: Number of Optuna trials to run.
            random_seed: Seed shared by the validation split, the Optuna
                sampler, every trial model and the final model.
        """
        self.n_trials = n_trials
        self.random_seed = random_seed
        self.best_model = None
        self.best_params = None
        self.study = None

    def objective(self, trial, X_train, y_train, X_val, y_val):
        """Optuna objective: validation accuracy of one sampled configuration.

        Args:
            trial: Optuna trial used to sample the hyperparameters.
            X_train, y_train: Data the candidate model is fitted on.
            X_val, y_val: Held-out data used as eval set and for scoring.

        Returns:
            Accuracy of the candidate model on the validation data.
        """
        params = {
            'iterations': trial.suggest_int('iterations', 100, 1000),
            'depth': trial.suggest_int('depth', 4, 10),
            'learning_rate': trial.suggest_float('learning_rate', 0.01, 0.3),
            'l2_leaf_reg': trial.suggest_float('l2_leaf_reg', 1, 10),
            'border_count': trial.suggest_int('border_count', 32, 255),
            'random_seed': self.random_seed,
            'verbose': False
        }

        model = CatBoostClassifier(**params)
        model.fit(X_train, y_train, eval_set=(X_val, y_val), verbose=False)

        return accuracy_score(y_val, model.predict(X_val))

    def optimize_and_train(self, X, y):
        """Tune hyperparameters, then fit the final model on all of ``X``/``y``.

        Args:
            X: Feature matrix.
            y: Class labels aligned with ``X``.

        Returns:
            The fitted ``CatBoostClassifier`` (also stored as ``best_model``).
        """
        print("Starting Optuna hyperparameter optimization...")

        # Stratified hold-out split used only for scoring the trials
        X_train, X_val, y_train, y_val = train_test_split(
            X, y, test_size=0.2, random_state=self.random_seed, stratify=y
        )

        # Create study
        self.study = optuna.create_study(
            direction='maximize',
            sampler=TPESampler(seed=self.random_seed)
        )

        # Optimize
        self.study.optimize(
            lambda trial: self.objective(trial, X_train, y_train, X_val, y_val),
            n_trials=self.n_trials,
            show_progress_bar=True
        )

        # Get best parameters
        self.best_params = self.study.best_params
        print(f"Best parameters: {self.best_params}")
        print(f"Best validation accuracy: {self.study.best_value:.4f}")

        # study.best_params only contains the sampled values, so the fixed
        # settings (seed, verbosity) must be re-applied; otherwise the final
        # model would be fitted with a different seed than the tuned trials.
        final_params = {
            **self.best_params,
            'random_seed': self.random_seed,
            'verbose': False,
        }
        self.best_model = CatBoostClassifier(**final_params)

        # Use full training data for final model
        self.best_model.fit(X, y, verbose=False)

        return self.best_model

    def _require_model(self):
        """Return the fitted model or raise ``ValueError`` if there is none."""
        if self.best_model is None:
            raise ValueError("Model not trained yet!")
        return self.best_model

    def predict(self, X):
        """Predict class labels with the fitted model.

        Raises:
            ValueError: If ``optimize_and_train`` has not been run.
        """
        return self._require_model().predict(X)

    def predict_proba(self, X):
        """Predict class probabilities with the fitted model.

        Raises:
            ValueError: If ``optimize_and_train`` has not been run.
        """
        return self._require_model().predict_proba(X)

    def get_feature_importance(self):
        """Return the fitted model's feature importances.

        Raises:
            ValueError: If ``optimize_and_train`` has not been run.
        """
        return self._require_model().get_feature_importance()

class SHAPExplainer:
    """SHAP-based explainable AI for model interpretation."""

    def __init__(self, model, X_train_sample):
        """Build a SHAP explainer around ``model.predict``.

        Args:
            model: Fitted model exposing ``predict``.
            X_train_sample: Training features; at most the first 100 rows are
                kept as the SHAP background data.
        """
        self.model = model
        # Use a smaller sample for SHAP background
        self.background = X_train_sample[:min(100, len(X_train_sample))]
        self.explainer = shap.Explainer(self.model.predict, self.background)

    def explain_prediction(self, X_instance, feature_names=None):
        """Generate a SHAP explanation for a single prediction.

        Args:
            X_instance: Feature vector (1-D) or a 2-D array whose first row is
                the instance of interest.
            feature_names: Optional names; defaults to ``PC1``, ``PC2``, ...

        Returns:
            Dict with the raw ``shap_values``, the five ``top_features`` as
            ``(name, importance)`` pairs sorted by decreasing importance, and
            the full ``feature_importance`` array.
        """
        if len(X_instance.shape) == 1:
            X_instance = X_instance.reshape(1, -1)

        shap_values = self.explainer(X_instance)

        if feature_names is None:
            feature_names = [f'PC{i+1}' for i in range(X_instance.shape[1])]

        # SHAP values of the first (only) instance
        shap_vals = shap_values.values[0]

        if len(shap_vals.shape) == 2:
            # Multi-output: sum absolute SHAP values across the outputs
            importance = np.abs(shap_vals).sum(axis=1)
        else:
            importance = np.abs(shap_vals)

        top_indices = np.argsort(importance)[::-1][:5]
        top_features = [(feature_names[i], importance[i]) for i in top_indices]

        return {
            'shap_values': shap_values,
            'top_features': top_features,
            'feature_importance': importance
        }

    def plot_explanation(self, shap_values, save_path=None):
        """Draw a waterfall plot (left) and a bar plot (right) for one instance.

        Args:
            shap_values: SHAP explanation as returned by the explainer.
            save_path: Optional file path the figure is saved to.

        Returns:
            The matplotlib figure holding both plots.
        """
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        single = shap_values[0]
        if len(shap_values.shape) == 3:
            # Both plots need one value per feature. For multi-output
            # explanations show the output with the largest total |SHAP|.
            output_idx = int(np.abs(shap_values.values[0]).sum(axis=0).argmax())
            single = shap_values[0, :, output_idx]

        # The SHAP plots draw on the current axes, so each target axes must be
        # activated *before* the plot call (plt.subplots leaves the last axes
        # current, which previously put both plots on the right-hand side).
        plt.sca(ax1)
        shap.plots.waterfall(single, show=False, max_display=10)
        ax1.set_title("SHAP Waterfall Plot")

        plt.sca(ax2)
        shap.plots.bar(single, show=False, max_display=10)
        ax2.set_title("Feature Importance")

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')

        return fig

class ReportGenerator:
    """Generate comprehensive PDF and Markdown reports"""

    def __init__(self):
        """Load reportlab's sample styles and define the report title style."""
        self.styles = getSampleStyleSheet()
        self.custom_style = ParagraphStyle(
            'CustomTitle',
            parent=self.styles['Heading1'],
            fontSize=16,
            textColor=Color(0.2, 0.4, 0.6),
            spaceAfter=12
        )

    @staticmethod
    def _format_metric(value, spec, suffix=""):
        """Format an optional numeric metric without ever raising.

        Numbers are rendered with ``spec`` followed by ``suffix``. Values that
        cannot take a numeric format (e.g. the ``'N/A'`` placeholder used for
        missing metrics, or ``None``) are shown as plain text instead of
        raising ``ValueError``/``TypeError`` from the format spec.
        """
        try:
            return format(value, spec) + suffix
        except (TypeError, ValueError):
            return "N/A" if value is None else str(value)

    def generate_markdown_report(self, prediction_result, save_path="brain_tumor_report.md"):
        """Generate a detailed markdown report.

        Args:
            prediction_result: Dict with the required keys
                ``predicted_class``, ``confidence``, ``top_features``,
                ``medical_insight`` and ``class_probabilities``; the remaining
                metrics are optional and rendered as ``N/A`` when absent.
            save_path: Destination of the markdown file.

        Returns:
            ``save_path``.
        """
        # Optional metrics: formatted defensively because they may be missing.
        explained_variance_text = self._format_metric(
            prediction_result.get('explained_variance', 'N/A'), '.3f')
        model_accuracy_text = self._format_metric(
            prediction_result.get('model_accuracy', 'N/A'), '.2f', suffix='%')

        report_content = f"""# 🧠 Brain Tumor Detection Report

**Generated on:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}

## 📊 Prediction Results

🔍 **Prediction:** {prediction_result['predicted_class']}
📈 **Confidence Score:** {prediction_result['confidence']:.2f}%
🧠 **Top Influencing Features:** {', '.join([f[0] for f in prediction_result['top_features'][:3]])}

## 🩺 Medical Insight

{prediction_result['medical_insight']}

## 📋 Detailed Analysis

### Preprocessing Steps:
1. ✅ **CLAHE Applied:** Contrast-Limited Adaptive Histogram Equalization for enhanced image quality
2. ✅ **Image Resized:** Standardized to 224x224 pixels for ViT compatibility
3. ✅ **Normalization:** Pixel values normalized to [0,1] range

### Feature Extraction:
- **Model Used:** Vision Transformer (ViT-B16) pre-trained
- **Feature Dimension:** 64-dimensional deep features extracted
- **PCA Reduction:** Dimensionality reduced to 45 components
- **Explained Variance:** {explained_variance_text}

### Classification:
- **Algorithm:** CatBoost Classifier
- **Hyperparameter Optimization:** Optuna-tuned parameters
- **Best Parameters:** {prediction_result.get('best_params', 'N/A')}
- **Model Accuracy:** {model_accuracy_text}

## 🎯 Class Probabilities

| Class | Probability | Confidence |
|-------|-------------|------------|
"""

        for class_name, prob in prediction_result['class_probabilities'].items():
            report_content += f"| {class_name} | {prob:.4f} | {prob*100:.2f}% |\n"

        report_content += """

## 🔍 SHAP Feature Analysis

### Top 5 Most Influential Features:

"""

        for i, (feature, importance) in enumerate(prediction_result['top_features'][:5], 1):
            report_content += f"{i}. **{feature}:** Impact score {importance:.4f}\n"

        report_content += f"""

### Feature Interpretation:
{prediction_result.get('feature_interpretation', 'Feature analysis completed using SHAP values to identify key contributing factors.')}

## 📈 Model Performance Summary

- **Training Samples:** {prediction_result.get('training_samples', 'N/A')}
- **Feature Extraction Time:** {prediction_result.get('extraction_time', 'N/A')}
- **Prediction Time:** {prediction_result.get('prediction_time', 'N/A')}
- **Cross-validation Score:** {prediction_result.get('cv_score', 'N/A')}

## ⚠️ Important Notes

1. This prediction is based on automated analysis and should be verified by medical professionals
2. The model was trained on publicly available datasets and may not cover all tumor variations
3. SHAP explanations provide insights into feature importance but require clinical interpretation
4. For critical medical decisions, always consult with qualified radiologists and oncologists

---
*Report generated by AI-powered Brain Tumor Classification System*
"""

        with open(save_path, 'w', encoding='utf-8') as f:
            f.write(report_content)

        print(f"✅ Markdown report saved to: {save_path}")
        return save_path

    def generate_pdf_report(self, prediction_result, save_path="brain_tumor_report.pdf"):
        """Generate a PDF report with the prediction, probabilities and setup.

        Args:
            prediction_result: Same structure as for
                ``generate_markdown_report``.
            save_path: Destination of the PDF file.

        Returns:
            ``save_path``.
        """
        doc = SimpleDocTemplate(save_path, pagesize=letter)
        story = []

        # Title
        title = Paragraph("🧠 Brain Tumor Detection Report", self.custom_style)
        story.append(title)
        story.append(Spacer(1, 12))

        # Prediction Results
        pred_text = f"""
        <b>🔍 Prediction:</b> {prediction_result['predicted_class']}<br/>
        <b>📈 Confidence Score:</b> {prediction_result['confidence']:.2f}%<br/>
        <b>🧠 Top Features:</b> {', '.join([f[0] for f in prediction_result['top_features'][:3]])}<br/>
        """
        story.append(Paragraph(pred_text, self.styles['Normal']))
        story.append(Spacer(1, 12))

        # Medical Insight
        insight_text = f"<b>🩺 Medical Insight:</b><br/>{prediction_result['medical_insight']}"
        story.append(Paragraph(insight_text, self.styles['Normal']))
        story.append(Spacer(1, 12))

        # Class Probabilities Table
        prob_data = [['Class', 'Probability', 'Confidence']]
        for class_name, prob in prediction_result['class_probabilities'].items():
            prob_data.append([class_name, f"{prob:.4f}", f"{prob*100:.2f}%"])

        prob_table = Table(prob_data)
        prob_table.setStyle(TableStyle([
            ('BACKGROUND', (0, 0), (-1, 0), Color(0.8, 0.8, 0.8)),
            ('TEXTCOLOR', (0, 0), (-1, 0), Color(0, 0, 0)),
            ('ALIGN', (0, 0), (-1, -1), 'CENTER'),
            ('FONTNAME', (0, 0), (-1, 0), 'Helvetica-Bold'),
            ('FONTSIZE', (0, 0), (-1, 0), 12),
            ('BOTTOMPADDING', (0, 0), (-1, 0), 12),
            ('BACKGROUND', (0, 1), (-1, -1), Color(0.95, 0.95, 0.95)),
        ]))

        story.append(prob_table)
        story.append(Spacer(1, 12))

        # Technical Details (accuracy is optional, so format it defensively)
        model_accuracy_text = self._format_metric(
            prediction_result.get('model_accuracy', 'N/A'), '.2f', suffix='%')
        tech_text = f"""
        <b>Technical Details:</b><br/>
        • Preprocessing: CLAHE + Resize to 224x224<br/>
        • Feature Extractor: ViT-B16 (64-dim features)<br/>
        • Dimensionality Reduction: PCA to 45 components<br/>
        • Classifier: CatBoost with Optuna optimization<br/>
        • Model Accuracy: {model_accuracy_text}<br/>
        """
        story.append(Paragraph(tech_text, self.styles['Normal']))

        doc.build(story)
        print(f"✅ PDF report saved to: {save_path}")
        return save_path

class BrainTumorClassifier:
    """Main class orchestrating the entire pipeline.

    Pipeline: CLAHE preprocessing -> ViT features -> standard scaling -> PCA
    -> Optuna-tuned CatBoost, with SHAP explanations and report generation.
    """

    def __init__(self, pca_components=45):
        """Create all pipeline stages.

        Args:
            pca_components: Number of principal components kept by the PCA.
        """
        self.preprocessor = MRIPreprocessor()
        self.feature_extractor = ViTFeatureExtractor(output_dim=64)
        self.scaler = StandardScaler()
        self.pca = PCA(n_components=pca_components)
        self.classifier = OptunaCatBoostClassifier(n_trials=30)
        self.label_encoder = LabelEncoder()
        self.shap_explainer = None
        self.report_generator = ReportGenerator()
        self.results = {}
        # Folder name on disk -> label used by the model and the reports
        self.class_mapping = {
            'notumor': 'no_tumor',
            'glioma': 'glioma',
            'meningioma': 'meningioma',
            'pituitary': 'pituitary'
        }
        self.model_trained = False

    @staticmethod
    def _format_metric(value, spec, suffix=""):
        """Format an optional numeric metric; non-numeric values become text.

        Prevents ``ValueError`` when a missing metric is represented by the
        ``'N/A'`` placeholder and a numeric format spec is applied to it.
        """
        try:
            return format(value, spec) + suffix
        except (TypeError, ValueError):
            return "N/A" if value is None else str(value)

    def load_training_data(self, base_path):
        """Collect image paths and labels from a one-folder-per-class layout.

        Only sub-folders whose name is a key of ``class_mapping`` are read,
        and only files with a known image extension are kept. Folders and
        files are visited in sorted order so the result is deterministic
        (``os.listdir`` order is arbitrary).

        Args:
            base_path: Directory containing the class folders.

        Returns:
            Tuple ``(image_paths, labels)`` of equal-length lists.
        """
        print("Loading training data...")

        image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
        image_paths = []
        labels = []

        # Load from directory structure
        for class_folder in sorted(os.listdir(base_path)):
            class_path = os.path.join(base_path, class_folder)
            if not os.path.isdir(class_path):
                continue

            if class_folder not in self.class_mapping:
                continue

            mapped_label = self.class_mapping[class_folder]

            for img_file in sorted(os.listdir(class_path)):
                if img_file.lower().endswith(image_extensions):
                    image_paths.append(os.path.join(class_path, img_file))
                    labels.append(mapped_label)

        print(f"Loaded {len(image_paths)} training images")
        return image_paths, labels

    def train_model(self, training_data_path):
        """Train the complete pipeline on the images below a directory.

        Args:
            training_data_path: Directory accepted by ``load_training_data``.

        Raises:
            ValueError: If no usable training image is found, instead of an
                obscure failure further down the pipeline.
        """
        print("🚀 Starting model training pipeline...")

        # Load training data
        image_paths, labels = self.load_training_data(training_data_path)
        if not image_paths:
            raise ValueError(
                f"No training images found in {training_data_path}. "
                f"Expected class folders: {sorted(self.class_mapping)}"
            )

        # Preprocess images
        print("🔄 Preprocessing images...")
        images, y = self.preprocessor.batch_preprocess(image_paths, labels)
        if len(images) == 0:
            raise ValueError("None of the training images could be preprocessed.")

        # Encode labels
        y_encoded = self.label_encoder.fit_transform(y)

        # Extract ViT features
        print("🎯 Extracting ViT features...")
        vit_features = self.feature_extractor.extract_features(images)

        # Scale features
        print("⚖️ Scaling features...")
        scaled_features = self.scaler.fit_transform(vit_features)

        # Apply PCA
        print("🔍 Applying PCA...")
        pca_features = self.pca.fit_transform(scaled_features)
        print(f"PCA explained variance: {self.pca.explained_variance_ratio_.sum():.3f}")

        # Train CatBoost with Optuna
        print("🤖 Training CatBoost classifier...")
        self.classifier.optimize_and_train(pca_features, y_encoded)

        # Initialize SHAP explainer
        print("🧠 Initializing SHAP explainer...")
        self.shap_explainer = SHAPExplainer(self.classifier.best_model, pca_features)

        # Store training info (model_accuracy is the best validation accuracy
        # reached during the Optuna search, in percent)
        self.results = {
            'training_samples': len(images),
            'explained_variance': self.pca.explained_variance_ratio_.sum(),
            'best_params': self.classifier.best_params,
            'model_accuracy': self.classifier.study.best_value * 100
        }

        self.model_trained = True
        print("✅ Model training completed!")

    def predict_single_image(self, image_input):
        """Predict a single image with comprehensive analysis.

        Args:
            image_input: Anything accepted by the preprocessor's
                ``preprocess_image``.

        Returns:
            Dict with ``predicted_class``, ``confidence`` (percent),
            ``class_probabilities``, ``shap_values``, ``top_features``,
            ``medical_insight``, ``prediction_time`` and the stored training
            information.

        Raises:
            ValueError: If the model has not been trained.
        """
        if not self.model_trained:
            raise ValueError("Model not trained yet! Call train_model() first.")

        start_time = datetime.now()

        # Preprocess image
        print("🔄 Preprocessing image...")
        processed_image = self.preprocessor.preprocess_image(image_input)

        # Extract features
        print("🎯 Extracting features...")
        vit_features = self.feature_extractor.extract_features([processed_image])
        scaled_features = self.scaler.transform(vit_features)
        pca_features = self.pca.transform(scaled_features)

        # Make prediction
        print("🤖 Making prediction...")
        # The classifier may return predictions as shape (n,) or (n, 1);
        # flatten so the class index is always a plain int. Using the raw
        # (1,)-shaped array as an index or passing it nested to
        # inverse_transform is what the old per-value conversions papered over.
        prediction = int(np.asarray(self.classifier.predict(pca_features)).ravel()[0])
        probabilities = np.asarray(self.classifier.predict_proba(pca_features))[0]

        # Get SHAP explanation
        print("🧠 Generating SHAP explanation...")
        shap_result = self.shap_explainer.explain_prediction(pca_features)

        # Decode prediction
        predicted_class = self.label_encoder.inverse_transform([prediction])[0]
        confidence = float(probabilities[prediction]) * 100

        # Probability per class name, as plain Python floats
        class_probabilities = {
            class_name: float(probabilities[i])
            for i, class_name in enumerate(self.label_encoder.classes_)
        }

        # Generate medical insight
        medical_insight = self.generate_medical_insight(predicted_class, shap_result['top_features'])

        prediction_time = (datetime.now() - start_time).total_seconds()

        return {
            'predicted_class': predicted_class,
            'confidence': confidence,
            'class_probabilities': class_probabilities,
            'shap_values': shap_result['shap_values'],
            'top_features': shap_result['top_features'],
            'medical_insight': medical_insight,
            'prediction_time': f"{prediction_time:.2f}s",
            **self.results
        }

    def generate_medical_insight(self, predicted_class, top_features):
        """Return the canned insight text for the predicted class.

        Args:
            predicted_class: Decoded class label.
            top_features: ``(name, importance)`` pairs; the first name is
                quoted in the text (``N/A`` if the list is empty).

        Returns:
            Insight sentence, or a generic fallback for unknown classes.
        """
        lead_feature = top_features[0][0] if top_features else "N/A"

        insights = {
            'glioma': f"The model identified characteristics typical of glioma tumors. Key feature {lead_feature} showed high impact, which often correlates with infiltrative growth patterns commonly seen in gliomas.",
            'meningioma': f"Features suggest meningioma characteristics. {lead_feature} significantly contributed to this classification, typically associated with well-defined tumor boundaries characteristic of meningiomas.",
            'pituitary': f"Pituitary tumor features detected. The influential feature {lead_feature} aligns with typical pituitary adenoma presentations in the sella turcica region.",
            'no_tumor': f"No tumor characteristics identified. Feature {lead_feature} strongly supported the healthy tissue classification, indicating normal brain structure patterns."
        }

        return insights.get(predicted_class, "Classification completed with SHAP feature analysis.")

    def generate_reports(self, prediction_result, output_dir="./reports"):
        """Generate both PDF and Markdown reports.

        Args:
            prediction_result: Result dict from ``predict_single_image``.
            output_dir: Directory for the reports; created if necessary.

        Returns:
            Tuple ``(md_path, pdf_path)``.
        """
        # exist_ok avoids the check-then-create race of a separate exists() test
        os.makedirs(output_dir, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Generate markdown report
        md_path = os.path.join(output_dir, f"brain_tumor_report_{timestamp}.md")
        self.report_generator.generate_markdown_report(prediction_result, md_path)

        # Generate PDF report
        pdf_path = os.path.join(output_dir, f"brain_tumor_report_{timestamp}.pdf")
        self.report_generator.generate_pdf_report(prediction_result, pdf_path)

        return md_path, pdf_path

    def display_prediction_summary(self, result):
        """Print a formatted prediction summary.

        Optional metrics (``model_accuracy``, ``explained_variance``) are
        shown as ``N/A`` when they are missing from ``result``.
        """
        print("\n" + "="*60)
        print("🧠 BRAIN TUMOR DETECTION RESULTS")
        print("="*60)
        print(f"🔍 **Prediction:** {result['predicted_class']}")
        print(f"📈 **Confidence Score:** {result['confidence']:.2f}%")

        top_features_str = ", ".join([f[0] for f in result['top_features'][:3]])
        print(f"🧠 **Top Influencing Features:** {top_features_str}")
        print(f"🩺 **Medical Insight:** {result['medical_insight']}")

        print(f"\n📊 **Class Probabilities:**")
        for class_name, prob in result['class_probabilities'].items():
            print(f"   {class_name}: {prob*100:.2f}%")

        accuracy_text = self._format_metric(result.get('model_accuracy', 'N/A'), '.2f', suffix='%')
        variance_text = self._format_metric(result.get('explained_variance', 'N/A'), '.3f')

        print(f"\n⚙️ **Technical Details:**")
        print(f"   • Model Accuracy: {accuracy_text}")
        print(f"   • Prediction Time: {result['prediction_time']}")
        print(f"   • PCA Explained Variance: {variance_text}")
        print("="*60)

# Example usage and demo function
def demo_brain_tumor_classifier(
    training_data_path="/content/binary_class/Training",
    image_path="/content/binary_class/Testing/glioma/Te-gl_0010.jpg",
):
    """Demonstration of the brain tumor classifier.

    Trains the pipeline on ``training_data_path``, classifies ``image_path``,
    prints the summary and writes the reports.

    Both paths are validated before the classifier is built and trained: the
    trained model only lives inside this function, so training without a
    valid test image (or predicting without training data) cannot produce a
    result.

    Args:
        training_data_path: Directory with one sub-folder per class.
        image_path: Image to classify after training.

    Returns:
        None. Errors raised while training or predicting are printed together
        with their traceback instead of being propagated.
    """

    print("🧠 Brain Tumor Classifier Demo")
    print("="*50)

    if not os.path.exists(training_data_path):
        # There is no fallback ("dummy") training, so stop here rather than
        # failing later with "Model not trained yet".
        print(f"⚠️ Training data path not found: {training_data_path}. "
              "Cannot run the demo without training data.")
        return

    if not os.path.exists(image_path):
        print("⚠️ Test image path not found. Please provide a valid image path for prediction.")
        return

    # Initialize classifier
    classifier = BrainTumorClassifier(pca_components=45)

    try:
        classifier.train_model(training_data_path)

        # Make prediction
        result = classifier.predict_single_image(image_path)

        # Display results
        classifier.display_prediction_summary(result)

        # Generate reports
        md_path, pdf_path = classifier.generate_reports(result)

        print(f"\n📄 Reports generated:")
        print(f"   • Markdown: {md_path}")
        print(f"   • PDF: {pdf_path}")

    except Exception as e:
        print(f"❌ Error in demo: {e}")
        import traceback
        traceback.print_exc()

def create_streamlit_app():
    """Render the Streamlit page for brain tumor classification.

    The page consists of a sidebar with static model information, an upload
    column that trains a ``BrainTumorClassifier`` on a user-supplied directory
    and classifies the uploaded image, a results column driven by
    ``st.session_state``, and a static technical-details footer.

    Returns:
        None. If Streamlit (or Pillow) cannot be imported, an installation
        hint is printed and nothing is rendered.
    """
    # Guard only the optional UI imports. An ImportError raised later (for
    # example while generating reports) must surface as itself instead of
    # being reported as "Streamlit not installed".
    try:
        import streamlit as st
        from PIL import Image
    except ImportError:
        print("Streamlit not installed. Install with: pip install streamlit")
        return

    st.set_page_config(
        page_title="🧠 Brain Tumor Classifier",
        page_icon="🧠",
        layout="wide"
    )

    st.title("🧠 Brain Tumor Classification System")
    st.markdown("---")

    # Sidebar for model information
    with st.sidebar:
        st.header("📋 Model Information")
        st.info("""
            **Features:**
            - ViT-B16 Feature Extraction
            - CLAHE Preprocessing
            - PCA Dimensionality Reduction
            - CatBoost Classification
            - SHAP Explainability
            - Comprehensive Reporting
            """)

        st.header("📊 Supported Classes")
        st.write("• Glioma")
        st.write("• Meningioma")
        st.write("• Pituitary")
        st.write("• No Tumor")

    # Main content
    col1, col2 = st.columns([1, 1])

    with col1:
        st.header("📤 Upload MRI Image")
        uploaded_file = st.file_uploader(
            "Choose an MRI image...",
            type=['png', 'jpg', 'jpeg', 'bmp', 'tiff']
        )

        if uploaded_file is not None:
            # Display uploaded image
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded MRI Image", use_column_width=True)

            # Training path input
            training_path = st.text_input(
                "Training Data Path:",
                placeholder="/path/to/training/data"
            )

            if st.button("🚀 Classify Image", type="primary"):
                if not training_path:
                    st.error("Please provide the training data path!")
                    return

                # Validate the path before the classifier is constructed.
                if not os.path.exists(training_path):
                    st.error("Training data path not found!")
                    return

                with st.spinner("🔄 Processing image and making prediction..."):
                    try:
                        # Streamlit re-executes the script on each
                        # interaction; reuse the classifier kept in session
                        # state unless the training path has changed.
                        classifier = st.session_state.get('classifier')
                        if (classifier is None
                                or st.session_state.get('trained_path') != training_path):
                            classifier = BrainTumorClassifier()
                            classifier.train_model(training_path)
                            st.session_state.classifier = classifier
                            st.session_state.trained_path = training_path

                        # Make prediction (the PIL image is passed as-is)
                        result = classifier.predict_single_image(image)

                        # Store result for the results column / later reruns
                        st.session_state.prediction_result = result

                    except Exception as e:
                        st.error(f"Error during classification: {e}")
                        return

    with col2:
        st.header("📊 Prediction Results")

        if 'prediction_result' in st.session_state:
            result = st.session_state.prediction_result

            # Main prediction
            st.success(f"🔍 **Prediction:** {result['predicted_class']}")
            st.info(f"📈 **Confidence:** {result['confidence']:.2f}%")

            # Class probabilities, highest first; "Score" is only the sort key
            st.subheader("📊 Class Probabilities")
            prob_df = pd.DataFrame([
                {"Class": k, "Probability": f"{v*100:.2f}%", "Score": v}
                for k, v in result['class_probabilities'].items()
            ]).sort_values('Score', ascending=False)

            st.dataframe(prob_df[['Class', 'Probability']], use_container_width=True)

            # Top features
            st.subheader("🧠 Top Influencing Features")
            features_text = "\n".join([
                f"• **{feat[0]}:** {feat[1]:.4f}"
                for feat in result['top_features'][:5]
            ])
            st.markdown(features_text)

            # Medical insight
            st.subheader("🩺 Medical Insight")
            st.write(result['medical_insight'])

            # Generate reports
            if st.button("📄 Generate Reports"):
                classifier = st.session_state.classifier
                with st.spinner("Generating reports..."):
                    md_path, pdf_path = classifier.generate_reports(result)

                    col_md, col_pdf = st.columns(2)
                    with col_md:
                        with open(md_path, 'r') as f:
                            st.download_button(
                                label="📄 Download Markdown Report",
                                data=f.read(),
                                file_name="brain_tumor_report.md",
                                mime="text/markdown"
                            )

                    with col_pdf:
                        with open(pdf_path, 'rb') as f:
                            st.download_button(
                                label="📄 Download PDF Report",
                                data=f.read(),
                                file_name="brain_tumor_report.pdf",
                                mime="application/pdf"
                            )
        else:
            st.info("👆 Upload an MRI image and click 'Classify Image' to see results")

    # Technical details section
    st.markdown("---")
    st.header("⚙️ Technical Details")

    tech_col1, tech_col2, tech_col3 = st.columns(3)

    with tech_col1:
        st.subheader("🔄 Preprocessing")
        st.write("• CLAHE Enhancement")
        st.write("• Resize to 224×224")
        st.write("• Normalization")

    with tech_col2:
        st.subheader("🎯 Feature Extraction")
        st.write("• ViT-B16 Pre-trained")
        st.write("• 64-dimensional features")
        st.write("• PCA to 45 components")

    with tech_col3:
        st.subheader("🤖 Classification")
        st.write("• CatBoost Classifier")
        st.write("• Optuna Optimization")
        st.write("• SHAP Explanations")

def create_gradio_app():
    """Build the Gradio Blocks interface for brain tumor classification.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched), or ``None`` when
        Gradio cannot be imported, in which case an install hint is printed.
    """
    # Guard only the optional dependency import so that unrelated errors while
    # building the interface are not reported as "Gradio not installed".
    try:
        import gradio as gr
    except ImportError:
        print("Gradio not installed. Install with: pip install gradio")
        return None

    # Classifier shared across button clicks, plus the path it was trained on
    # so a different training directory triggers retraining.
    global_classifier = None
    trained_path = None

    def classify_image(image, training_path):
        """Gradio callback: train if needed, classify ``image``, format output.

        Returns a 5-tuple of markdown strings (prediction, confidence, top
        features, class probabilities, medical insight). On invalid input or
        failure the first element carries the message and the rest are empty.
        """
        nonlocal global_classifier, trained_path

        if image is None:
            return "Please upload an image", "", "", "", ""

        if not training_path or not os.path.exists(training_path):
            return "Invalid training data path", "", "", "", ""

        try:
            if global_classifier is None or trained_path != training_path:
                # Train a local instance first and publish it only on success;
                # otherwise a failed training run would leave an untrained
                # classifier cached and later clicks would skip training.
                fresh_classifier = BrainTumorClassifier()
                fresh_classifier.train_model(training_path)
                global_classifier = fresh_classifier
                trained_path = training_path

            # Make prediction
            result = global_classifier.predict_single_image(image)

            # Format results
            prediction = f"🔍 **Prediction:** {result['predicted_class']}"
            confidence = f"📈 **Confidence:** {result['confidence']:.2f}%"

            top_features = "🧠 **Top Features:**\n" + "\n".join([
                f"• {feat[0]}: {feat[1]:.4f}"
                for feat in result['top_features'][:5]
            ])

            probabilities = "📊 **Class Probabilities:**\n" + "\n".join([
                f"• {k}: {v*100:.2f}%"
                for k, v in result['class_probabilities'].items()
            ])

            medical_insight = f"🩺 **Medical Insight:**\n{result['medical_insight']}"

            return prediction, confidence, top_features, probabilities, medical_insight

        except Exception as e:
            # Surface the failure in the first output slot instead of raising
            # inside the Gradio callback.
            return f"Error: {str(e)}", "", "", "", ""

    # Create Gradio interface
    with gr.Blocks(title="🧠 Brain Tumor Classifier") as app:
        gr.Markdown("# 🧠 Brain Tumor Classification System")
        gr.Markdown("Upload an MRI image for automated brain tumor classification using ViT + CatBoost + SHAP")

        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(
                    type="pil",
                    label="📤 Upload MRI Image"
                )

                training_path_input = gr.Textbox(
                    label="📁 Training Data Path",
                    placeholder="/path/to/training/data",
                    value=""
                )

                classify_btn = gr.Button(
                    "🚀 Classify Image",
                    variant="primary"
                )

            with gr.Column(scale=1):
                prediction_output = gr.Markdown(label="🔍 Prediction")
                confidence_output = gr.Markdown(label="📈 Confidence")
                features_output = gr.Markdown(label="🧠 Top Features")
                probabilities_output = gr.Markdown(label="📊 Probabilities")
                insight_output = gr.Markdown(label="🩺 Medical Insight")

        # Connect the classification function; output order matches the
        # 5-tuple returned by classify_image.
        classify_btn.click(
            fn=classify_image,
            inputs=[image_input, training_path_input],
            outputs=[
                prediction_output,
                confidence_output,
                features_output,
                probabilities_output,
                insight_output
            ]
        )

        # Static technical summary
        gr.Markdown("## 📋 Technical Specifications")
        gr.Markdown("""
            - **Preprocessing:** CLAHE + Resize to 224×224
            - **Feature Extraction:** Vision Transformer (ViT-B16)
            - **Dimensionality Reduction:** PCA to 45 components
            - **Classification:** CatBoost with Optuna optimization
            - **Explainability:** SHAP feature importance analysis
            - **Supported Classes:** Glioma, Meningioma, Pituitary, No Tumor
            """)

    return app

# Command-line interface
def main():
    """Entry point: run the demo in Colab, otherwise dispatch on ``--mode``.

    Outside Colab the following command-line options are parsed:
        --mode           one of train, predict, demo, streamlit, gradio
                         (default: demo)
        --training_data  training data directory (required for train/predict)
        --image          image to classify (required for predict)
        --output_dir     report directory for predict mode (default: ./reports)

    Returns:
        None. Missing required options are reported on stdout and the
        function returns without doing any work.
    """
    # google.colab is importable only inside a Colab runtime. Catch ImportError
    # specifically: a bare except would also swallow KeyboardInterrupt and
    # SystemExit.
    try:
        import google.colab  # noqa: F401
        IN_COLAB = True
    except ImportError:
        IN_COLAB = False

    if IN_COLAB:
        # In Colab, run the demo directly (command-line arguments are ignored)
        demo_brain_tumor_classifier()
        return

    # Outside of Colab, use argparse
    import argparse

    parser = argparse.ArgumentParser(description="🧠 Brain Tumor Classification System")
    parser.add_argument('--mode', choices=['train', 'predict', 'demo', 'streamlit', 'gradio'],
                       default='demo', help='Operation mode')
    parser.add_argument('--training_data', type=str, help='Path to training data directory')
    parser.add_argument('--image', type=str, help='Path to image for prediction')
    parser.add_argument('--output_dir', type=str, default='./reports', help='Output directory for reports')

    args = parser.parse_args()

    if args.mode == 'demo':
        demo_brain_tumor_classifier()

    elif args.mode == 'train':
        if not args.training_data:
            print("❌ Training data path required for training mode")
            return

        classifier = BrainTumorClassifier()
        classifier.train_model(args.training_data)
        print("✅ Model training completed!")

    elif args.mode == 'predict':
        if not args.training_data or not args.image:
            print("❌ Both training data path and image path required for prediction mode")
            return

        # Each prediction run trains a fresh model before classifying.
        classifier = BrainTumorClassifier()
        classifier.train_model(args.training_data)

        result = classifier.predict_single_image(args.image)
        classifier.display_prediction_summary(result)

        # Generate reports
        md_path, pdf_path = classifier.generate_reports(result, args.output_dir)
        print(f"\n📄 Reports saved:")
        print(f"   • Markdown: {md_path}")
        print(f"   • PDF: {pdf_path}")

    elif args.mode == 'streamlit':
        print("🚀 Starting Streamlit app...")
        # Script arguments must follow a bare "--"; otherwise the Streamlit
        # CLI treats --mode as one of its own options.
        print("Run: streamlit run this_script.py -- --mode streamlit")
        create_streamlit_app()

    elif args.mode == 'gradio':
        print("🚀 Starting Gradio app...")
        app = create_gradio_app()
        if app:
            app.launch(share=True)

# Entry point: invoke main() when this code runs as the top-level program
# (the guard keeps main() from running if the code is imported as a module).
if __name__ == "__main__":
    main()

🧠 Brain Tumor Classifier Demo
Using device: cpu


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

🚀 Starting model training pipeline...
Loading training data...
Loaded 5712 training images
🔄 Preprocessing images...
Processing 5712 images...
Processed 100/5712 images (1.8%)
Processed 200/5712 images (3.5%)
Processed 300/5712 images (5.3%)
Processed 400/5712 images (7.0%)
Processed 500/5712 images (8.8%)
Processed 600/5712 images (10.5%)
Processed 700/5712 images (12.3%)
Processed 800/5712 images (14.0%)
Processed 900/5712 images (15.8%)
Processed 1000/5712 images (17.5%)
Processed 1100/5712 images (19.3%)
Processed 1200/5712 images (21.0%)
Processed 1300/5712 images (22.8%)
Processed 1400/5712 images (24.5%)
Processed 1500/5712 images (26.3%)
Processed 1600/5712 images (28.0%)
Processed 1700/5712 images (29.8%)
Processed 1800/5712 images (31.5%)
Processed 1900/5712 images (33.3%)
Processed 2000/5712 images (35.0%)
Processed 2100/5712 images (36.8%)
Processed 2200/5712 images (38.5%)
Processed 2300/5712 images (40.3%)
Processed 2400/5712 images (42.0%)
Processed 2500/5712 images (43.

[I 2025-07-30 05:01:07,653] A new study created in memory with name: no-name-40a4131e-7da3-41ff-964c-241e0cfe8e07


Processed batch 357/357
⚖️ Scaling features...
🔍 Applying PCA...
PCA explained variance: 0.942
🤖 Training CatBoost classifier...
Starting Optuna hyperparameter optimization...


  0%|          | 0/30 [00:00<?, ?it/s]

[I 2025-07-30 05:02:17,779] Trial 0 finished with value: 0.8731408573928259 and parameters: {'iterations': 437, 'depth': 10, 'learning_rate': 0.22227824312530747, 'l2_leaf_reg': 6.387926357773329, 'border_count': 66}. Best is trial 0 with value: 0.8731408573928259.
[I 2025-07-30 05:02:20,524] Trial 1 finished with value: 0.8608923884514436 and parameters: {'iterations': 240, 'depth': 4, 'learning_rate': 0.2611910822747312, 'l2_leaf_reg': 6.41003510568888, 'border_count': 190}. Best is trial 0 with value: 0.8731408573928259.
[I 2025-07-30 05:02:41,915] Trial 2 finished with value: 0.8705161854768154 and parameters: {'iterations': 118, 'depth': 10, 'learning_rate': 0.2514083658321223, 'l2_leaf_reg': 2.9110519961044856, 'border_count': 72}. Best is trial 0 with value: 0.8731408573928259.
[I 2025-07-30 05:02:46,838] Trial 3 finished with value: 0.8696412948381452 and parameters: {'iterations': 265, 'depth': 6, 'learning_rate': 0.16217936517334897, 'l2_leaf_reg': 4.887505167779041, 'border_