In [1]:
import os
import numpy as np
import cv2
from PIL import Image
from tensorflow.keras.models import load_model # type: ignore
from ultralytics import YOLO
from IPython.display import display
import ipywidgets as widgets
import io
import shutil

# Load both models up front. Each load is guarded independently so a failure
# for one model is reported without preventing the other from loading; on
# failure the corresponding name is simply left undefined.
try:
    classification_model = load_model('models/model.h5')
except Exception as load_err:
    print(f"Error loading classification model: {load_err}")
else:
    print("Classification model loaded successfully.")

try:
    segmentation_model = YOLO("models/best.pt")
except Exception as load_err:
    print(f"Error loading segmentation model: {load_err}")
else:
    print("Segmentation model loaded successfully.")

# Scratch directories: raw uploads, and the annotated segmentation outputs.
UPLOAD_FOLDER = 'temp/uploads'
SEGMENTATION_FOLDER = 'temp/segmentation_results'
for _work_dir in (UPLOAD_FOLDER, SEGMENTATION_FOLDER):
    os.makedirs(_work_dir, exist_ok=True)
del _work_dir

# Path of the most recently saved upload; written by the upload callback and
# read by the predict callback. Stays None until a file has been uploaded.
uploaded_file_path = None

# Classification function
def classify_image(image_data):
    """Classify an encoded image into one of four tumor classes.

    Args:
        image_data: Raw encoded image bytes (any format PIL can decode).

    Returns:
        The predicted class label, one of ``'Glioma Tumor'``, ``'No Tumor'``,
        ``'Meningioma Tumor'`` or ``'Pituitary Tumor'``. On any failure a
        string starting with ``"Classification Error:"`` is returned instead
        of raising.
    """
    try:
        # Force three channels. Grayscale ('L'), palette ('P') and RGBA images
        # would otherwise yield arrays that cannot be reshaped to
        # (1, 150, 150, 3). For images already in RGB mode this is a no-op.
        img = Image.open(io.BytesIO(image_data)).convert('RGB')
        img_resized = cv2.resize(np.array(img), (150, 150))
        # The model is fed a single-image batch of 150x150x3 pixels.
        img_batch = img_resized.reshape(1, 150, 150, 3)
        prediction = classification_model.predict(img_batch)
        class_index = int(np.argmax(prediction, axis=1)[0])
        # Order must match the model's output units.
        classes = ['Glioma Tumor', 'No Tumor', 'Meningioma Tumor', 'Pituitary Tumor']
        return classes[class_index]
    except Exception as e:
        return f"Classification Error: {str(e)}"

# Segmentation function
def segment_image(image_path):
    """Run the segmentation model on an image file and collect the rendered result.

    Args:
        image_path: Path of the image file on disk.

    Returns:
        On success, the path of the annotated image inside
        ``SEGMENTATION_FOLDER``. On failure, a string starting with
        ``"Error:"`` or ``"Segmentation Error:"``. Callers tell the two apart
        with ``os.path.exists``.
    """
    try:
        output_dir = "runs/segment/predict"
        # Pin the save directory. Without exist_ok=True, Ultralytics creates
        # predict2, predict3, ... on each call, so a fixed "predict" path
        # would only find the output of the very first prediction.
        segmentation_model(
            image_path,
            save=True,
            project=os.path.dirname(output_dir),
            name=os.path.basename(output_dir),
            exist_ok=True,
        )

        base_name = os.path.basename(image_path)
        stem = os.path.splitext(base_name)[0]
        # NOTE(review): some Ultralytics versions re-encode saved image
        # predictions as .jpg regardless of the input extension, so accept
        # either name.
        candidates = [base_name, stem + ".jpg"]
        segmented_image_path = next(
            (
                os.path.join(output_dir, name)
                for name in candidates
                if os.path.exists(os.path.join(output_dir, name))
            ),
            None,
        )
        if segmented_image_path is None:
            return f"Error: Segmented image not found in {output_dir}"

        final_path = os.path.join(
            SEGMENTATION_FOLDER, os.path.basename(segmented_image_path)
        )
        # Re-segmenting the same file would otherwise collide with the
        # previous result; shutil.move onto an existing file fails on Windows.
        if os.path.exists(final_path):
            os.remove(final_path)
        shutil.move(segmented_image_path, final_path)
        return final_path
    except Exception as e:
        return f"Segmentation Error: {str(e)}"

# Create widgets
# Output panes, filled in by the upload and predict callbacks.
classification_output = widgets.Output()
segmentation_output = widgets.Output()
# Single-file image picker, and the button that triggers inference.
upload_button = widgets.FileUpload(multiple=False, accept='image/*')
predict_button = widgets.Button(button_style='success', description="Predict")

# Callback for handling image uploads
def handle_upload(change):
    """Save a newly uploaded image into UPLOAD_FOLDER and show a preview.

    Registered via ``upload_button.observe(..., names='value')``.
    ``upload_button.value`` is read directly: it may be a dict keyed by
    filename or a sequence of file dicts, and both shapes are handled. On
    success the saved path is stored in the global ``uploaded_file_path`` for
    the predict callback. On failure that global is reset to None.
    """
    global uploaded_file_path
    classification_output.clear_output()
    segmentation_output.clear_output()

    # Ensure a file is uploaded
    uploaded_file = upload_button.value
    if not uploaded_file:
        with classification_output:
            print("No file uploaded. Please upload an image first.")
        return

    try:
        if isinstance(uploaded_file, dict):
            # Dict form: the key is the filename, the value holds 'content'.
            key_name, file_info = next(iter(uploaded_file.items()))
        else:
            # Sequence form: each entry carries its own 'name' and 'content'.
            key_name, file_info = None, uploaded_file[0]
        file_data = file_info['content']
        file_name = file_info.get('name') or key_name or 'uploaded_image.jpg'
        # Keep only the final path component so a crafted name cannot write
        # outside UPLOAD_FOLDER.
        file_name = os.path.basename(file_name) or 'uploaded_image.jpg'
        target_path = os.path.join(UPLOAD_FOLDER, file_name)

        with open(target_path, 'wb') as f:
            f.write(file_data)
        # Publish the path only once the file is fully written.
        uploaded_file_path = target_path
        with classification_output:
            print("Image uploaded successfully. Click 'Predict' to process.")
            display(Image.open(io.BytesIO(file_data)))
    except Exception as e:
        # Do not let Predict run on a missing/partial file or on a stale
        # previous upload after this one failed.
        uploaded_file_path = None
        with classification_output:
            print(f"Error handling upload: {e}")

# Callback for handling predictions
def handle_predict(change):
    """Run classification and segmentation on the most recently uploaded image.

    Registered via ``predict_button.on_click``. ``change`` receives whatever
    on_click passes and is not used. Results (or error messages) are written
    to ``classification_output`` and ``segmentation_output``.
    """
    if uploaded_file_path is None:
        with classification_output:
            classification_output.clear_output()
            print("Please upload an image first.")
        return

    with classification_output:
        classification_output.clear_output()
        print("Processing image for classification...")
    with segmentation_output:
        segmentation_output.clear_output()
        print("Processing image for segmentation...")

    try:
        with open(uploaded_file_path, 'rb') as f:
            file_data = f.read()
    except OSError as e:
        # Without this, the exception escapes the callback and both
        # "Processing..." placeholders stay on screen indefinitely.
        for pane in (classification_output, segmentation_output):
            with pane:
                pane.clear_output()
                print(f"Error reading uploaded image: {e}")
        return

    # Perform classification
    classification_result = classify_image(file_data)
    with classification_output:
        classification_output.clear_output()
        print(f"Classification Result: {classification_result}")

    # Perform segmentation. segment_image returns a file path on success and
    # an error message otherwise, so existence on disk decides which it is.
    segmented_path = segment_image(uploaded_file_path)
    with segmentation_output:
        segmentation_output.clear_output()
        if os.path.exists(segmented_path):
            print("Segmentation completed. Result displayed below:")
            display(Image.open(segmented_path))
        else:
            print(segmented_path)

# Attach callbacks: clicking Predict runs inference; a new file selection
# (a change of the uploader's 'value' trait) triggers the upload handler.
predict_button.on_click(handle_predict)
upload_button.observe(handle_upload, names='value')

# Display the interface: controls on top, then the two labelled result panes.
display(
    widgets.VBox(
        children=[
            widgets.Label("Upload an Image for Classification and Segmentation:"),
            upload_button,
            predict_button,
            widgets.Label("Classification Result:"),
            classification_output,
            widgets.Label("Segmentation Result:"),
            segmentation_output,
        ]
    )
)




Classification model loaded successfully.
Segmentation model loaded successfully.


VBox(children=(Label(value='Upload an Image for Classification and Segmentation:'), FileUpload(value=(), accep…