# In [None]:
# Import necessary libraries
import os

import cv2
import numpy as np
import requests
import streamlit as st
from fastapi import FastAPI, File, HTTPException
from fastapi.responses import JSONResponse
from tensorflow.keras.models import load_model

# Load the trained model. The path may be overridden with the UNET_MODEL_PATH
# environment variable; it defaults to 'unet_model.h5' as before.
# compile=False: the model is only used for inference (model.predict) in this
# file, so the optimizer and training loss/metrics -- which may be custom
# objects that would otherwise have to be passed via custom_objects -- are not
# restored.
# NOTE(review): this file acts as both the FastAPI app and the Streamlit
# script, so the model is loaded by every process that runs/imports it.
model = load_model(os.environ.get("UNET_MODEL_PATH", "unet_model.h5"), compile=False)

# FastAPI App
app = FastAPI()

@app.post("/predict/")
def predict(image: bytes = File(...)):
    """Run U-Net segmentation on one uploaded image.

    The image must arrive as a multipart form field named ``image`` -- the
    Streamlit client in this file posts ``files={"image": ...}``. A bare
    ``bytes`` parameter would instead be read from the query string, and every
    such multipart request would be rejected with 422. (Form/file parsing in
    FastAPI requires the ``python-multipart`` package to be installed.)

    Declared with plain ``def`` rather than ``async def`` so FastAPI runs the
    blocking OpenCV / ``model.predict`` work in its threadpool instead of
    stalling the event loop.

    Args:
        image: Encoded image bytes in any format ``cv2.imdecode`` accepts.

    Returns:
        JSONResponse ``{"prediction": ...}`` holding the model output
        thresholded at 0.5 as nested lists of 0/1 values. Presumably shaped
        (1, 256, 256, 1) -- the client reshapes it to 256x256; confirm against
        the model's output layer.

    Raises:
        HTTPException: 400 if the payload is empty or cannot be decoded.
    """
    if not image:
        # cv2.imdecode raises cv2.error on an empty buffer instead of returning None.
        raise HTTPException(status_code=400, detail="Empty image payload.")
    nparr = np.frombuffer(image, np.uint8)
    img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # imdecode signals an undecodable/unsupported payload by returning None;
        # without this check cv2.resize would fail with an opaque 500.
        raise HTTPException(status_code=400, detail="Could not decode image.")
    img = cv2.resize(img, (256, 256))
    # IMREAD_GRAYSCALE yields 8-bit pixels, so dividing by 255 maps them to [0, 1].
    img = img.astype('float32') / 255.0
    img = img[np.newaxis, ..., np.newaxis]  # add batch and channel axes -> (1, 256, 256, 1)

    prediction = model.predict(img)
    # Binarize at 0.5 (assumes sigmoid-style outputs in [0, 1] -- TODO confirm).
    prediction = (prediction > 0.5).astype(np.uint8)
    return JSONResponse(content={"prediction": prediction.tolist()})

# Streamlit app: upload an MRI slice, show it, send it to the FastAPI
# /predict/ endpoint and display the returned binary mask.
st.title("Brain MRI Metastasis Segmentation")
uploaded_file = st.file_uploader("Upload an MRI image", type=["tiff", "png", "jpg"])

if uploaded_file is not None:
    # Read the upload once; the same bytes are displayed locally and sent to the API.
    file_bytes = uploaded_file.getvalue()
    image = cv2.imdecode(np.frombuffer(file_bytes, np.uint8), cv2.IMREAD_GRAYSCALE)

    if image is None:
        # imdecode returns None for data it cannot decode; st.image(None) would crash.
        st.error("Could not decode the uploaded image.")
    else:
        st.image(image, caption='Uploaded MRI Image', use_column_width=True)

        # Call FastAPI endpoint. The timeout keeps the UI from hanging forever if
        # the server is unresponsive; connection errors are shown, not raised.
        try:
            response = requests.post(
                "http://localhost:8000/predict/",
                files={"image": file_bytes},
                timeout=60,
            )
        except requests.exceptions.RequestException as err:
            st.error(f"Prediction failed: {err}")
        else:
            if response.status_code == 200:
                prediction = np.array(response.json()['prediction'], dtype=np.uint8)
                # The server returns a 0/1 mask; st.image interprets integer data on
                # a 0-255 scale, so scale it up or the mask renders as all black.
                mask = prediction.reshape(256, 256) * 255
                st.image(mask, caption='Predicted Segmentation', use_column_width=True)
            else:
                st.error("Prediction failed.")
