In [None]:
import base64
import io
import logging
import tempfile

# Third-party imports keep their original relative order: segmentation_models
# configures its backend framework at import time, after tensorflow.keras.
import nest_asyncio
import uvicorn
from fastapi import FastAPI, UploadFile, File
from fastapi.responses import JSONResponse
import numpy as np
import rasterio
from tensorflow.keras.models import load_model
import segmentation_models as sm
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from PIL import Image
import requests
# Patch asyncio so uvicorn.run() can start a server inside an event loop that is
# already running (as it is in this notebook's kernel).
nest_asyncio.apply()

# Names Keras has to resolve when rebuilding the saved model's compile config.
# 'dice_loss' maps to the DiceLoss *class*; presumably Keras instantiates it when
# deserializing by name — NOTE(review): confirm against the Keras version in use.
custom_objects = dict(
    iou_score=sm.metrics.iou_score,
    dice_loss=sm.losses.DiceLoss,
)

# Loaded once when the cell runs and shared by every request.
model = load_model("best_model.h5", custom_objects=custom_objects)

app = FastAPI()

def preprocess_raster(file_bytes):
    """Convert uploaded raster bytes into a batch of one 9-channel model input.

    Pipeline: read every band, min-max scale each band independently to
    [0, 1], swap bands 1 and 3, compute MNDWI / NDWI / NDVI from the
    (post-swap) bands the code treats as green, red, NIR, SWIR1 and SWIR2,
    then stack ``bands 1-3 + 3 indices + bands 5-7`` in channels-last layout.

    Args:
        file_bytes: Raw bytes of a raster file readable by rasterio/GDAL
            (presumably a GeoTIFF — the format is not checked here).

    Returns:
        float32 array of shape (1, 128, 128, 9).

    Raises:
        ValueError: if the raster has fewer than 8 bands or is not 128x128.
        Any rasterio error raised while opening/reading the bytes propagates.
    """
    # Decode in memory: writing a NamedTemporaryFile(delete=False) and never
    # removing it left one orphaned .tif on disk for every request.
    with rasterio.io.MemoryFile(file_bytes) as memfile:
        with memfile.open() as src:
            image = src.read()  # (bands, H, W); earlier note says 12 bands — confirm

    num_bands, height, width = image.shape
    # Bands up to index 7 are used below; fewer bands would either raise an
    # opaque IndexError or silently build a tensor with too few channels.
    if num_bands < 8:
        raise ValueError(f"Expected at least 8 bands, got {num_bands}")
    # Reject wrongly sized inputs before doing any computation.
    if (height, width) != (128, 128):
        raise ValueError(f"Expected image shape (128, 128), got {(height, width)}")

    # Per-band min-max scaling to [0, 1]. Casting to float *before* subtracting
    # prevents overflow in `band - band.min()` for signed integer rasters.
    # Nodata / NaN values are not masked.
    image = image.astype(np.float64)
    band_min = image.min(axis=(1, 2), keepdims=True)
    band_max = image.max(axis=(1, 2), keepdims=True)
    normalized = ((image - band_min) / (band_max - band_min + 1e-5)).astype(np.float32)

    # Swap bands 1 and 3; the fancy-indexed right-hand side is a copy, so this
    # is a true swap. NOTE(review): presumably aligns the band order with the
    # order the model was trained on — confirm.
    normalized[[1, 3]] = normalized[[3, 1]]

    green = normalized[2]
    red = normalized[3]
    nir = normalized[4]
    swir2 = normalized[5]
    swir1 = normalized[6]

    # Normalized-difference indices; the small epsilon avoids division by zero.
    mndwi = (green - swir1) / (green + swir1 + 1e-6)
    ndwi = (green - nir) / (green + nir + 1e-6)
    ndvi = (nir - red) / (nir + red + 1e-6)
    indices = np.clip(np.stack([mndwi, ndwi, ndvi], axis=0), -1, 1)

    nine_band_input = np.concatenate([normalized[1:4], indices, normalized[5:8]], axis=0)

    # (9, H, W) -> (H, W, 9), then add the batch axis -> (1, H, W, 9).
    return np.expand_dims(np.transpose(nine_band_input, (1, 2, 0)), axis=0)
_logger = logging.getLogger(__name__)


def _mask_to_png_base64(mask):
    """Encode a 2-D 0/1 mask as a base64 PNG string (0 -> black, 1 -> white)."""
    img_pil = Image.fromarray((mask * 255).astype(np.uint8)).convert("L")
    buffer = io.BytesIO()
    img_pil.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def _channels_to_png_base64(x, cols=6):
    """Render each channel of batch item 0 of a channels-last array as a subplot grid.

    Returns the rendered figure as a base64-encoded PNG string.
    """
    num_channels = x.shape[-1]
    rows = (num_channels + cols - 1) // cols  # ceil(num_channels / cols)
    # Object-oriented Figure instead of pyplot: pyplot keeps global state and a
    # registry of open figures, which leaks a figure whenever an exception
    # skips plt.close() in a long-running server.
    fig = Figure(figsize=(3 * cols, 3 * rows))
    # squeeze=False always yields a 2-D axes array, so rows == 1 needs no special case.
    axes = fig.subplots(rows, cols, squeeze=False)
    for i, ax in enumerate(axes.flat):  # row-major, same order as [i // cols, i % cols]
        if i < num_channels:
            ax.imshow(x[0, :, :, i], cmap='RdYlGn')
            ax.set_title(f"Channel {i}")
        ax.axis("off")
    fig.tight_layout()
    buffer = io.BytesIO()
    fig.savefig(buffer, format="PNG", bbox_inches='tight')
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Segment an uploaded raster and return the mask plus a channel preview.

    Success response JSON:
        status: "success"
        image_base64: PNG of the binary mask (white where prediction > 0.5).
        channels_base64: PNG grid showing each of the 9 model-input channels.
        prediction_shape: shape of the mask as a list, e.g. [128, 128].

    On failure returns ``{"error": <message>}``: status 400 for input problems
    (ValueError, rasterio errors) and 500 for anything else.
    """
    try:
        contents = await file.read()
        x = preprocess_raster(contents)  # (1, 128, 128, 9)

        # NOTE(review): predict and the rendering below are synchronous inside an
        # async endpoint, so they block the event loop while they run.
        # verbose=0 stops Keras from printing a progress bar on every request.
        pred = model.predict(x, verbose=0)[0]  # (128, 128, 1) or (128, 128)
        # Assumes per-pixel probabilities (e.g. sigmoid output) — TODO confirm.
        binary_pred = (pred > 0.5).astype(np.uint8).squeeze()  # (128, 128)

        img_base64 = _mask_to_png_base64(binary_pred)
        fig_base64 = _channels_to_png_base64(x)
    except (ValueError, rasterio.errors.RasterioError) as e:
        # Bad upload (unreadable raster, wrong band count or size): client error.
        return JSONResponse({"error": str(e)}, status_code=400)
    except Exception as e:
        # Anything else is a server-side failure; keep the traceback in the logs.
        _logger.exception("Prediction failed")
        return JSONResponse({"error": str(e)}, status_code=500)

    return JSONResponse({
        "status": "success",
        "image_base64": img_base64,
        "channels_base64": fig_base64,
        "prediction_shape": list(binary_pred.shape),
    })
    
    
# Start the server only when run as a script or notebook cell (where
# __name__ == "__main__"), not when the module is imported, e.g. by
# `uvicorn module:app` or a test; uvicorn.run() blocks until shutdown.
# NOTE(review): host "0.0.0.0" listens on all interfaces; use "127.0.0.1"
# if the endpoint should not be reachable from other machines.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)

Segmentation Models: using `keras` framework.


INFO:     Started server process [23152]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)


INFO:     127.0.0.1:50230 - "POST /predict HTTP/1.1" 200 OK
INFO:     127.0.0.1:50287 - "POST /predict HTTP/1.1" 200 OK
INFO:     127.0.0.1:50316 - "POST /predict HTTP/1.1" 200 OK
INFO:     127.0.0.1:50323 - "POST /predict HTTP/1.1" 200 OK
INFO:     127.0.0.1:50329 - "POST /predict HTTP/1.1" 200 OK
INFO:     127.0.0.1:50338 - "POST /predict HTTP/1.1" 200 OK
INFO:     127.0.0.1:50347 - "POST /predict HTTP/1.1" 200 OK
INFO:     127.0.0.1:50355 - "POST /predict HTTP/1.1" 200 OK


INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [23152]
