# In [None]:
import base64

import cv2
import numpy as np
from fastapi import FastAPI, File, UploadFile
from tensorflow.keras.models import load_model

# Filesystem path of the saved Keras colorization model.
# Replace with the path to your downloaded model.
model_path = "path/to/colorization_model.h5"

# Load the colorization model once at import time; colorize_image() reuses
# this module-level object for every prediction.
model = load_model(model_path)

# FastAPI application object; the /upload route below is registered on it.
app = FastAPI()


def preprocess_image(image_bytes):
    """Decode uploaded image bytes into a batched grayscale array.

    Args:
        image_bytes: Raw encoded image data (for example the contents of a
            JPEG or PNG file).

    Returns:
        The image decoded as single-channel grayscale, with a leading batch
        axis added by ``np.expand_dims(..., axis=0)``.

    Raises:
        ValueError: If ``image_bytes`` is empty or cannot be decoded as an
            image.
    """
    if not image_bytes:
        raise ValueError("Uploaded file is empty.")

    encoded = np.frombuffer(image_bytes, np.uint8)
    image = cv2.imdecode(encoded, cv2.IMREAD_GRAYSCALE)
    # cv2.imdecode reports undecodable data by returning None instead of
    # raising. Without this check the None would be wrapped by expand_dims
    # and only fail later, inside the model, with an unrelated error.
    if image is None:
        raise ValueError("Uploaded file could not be decoded as an image.")

    # If the model requires a fixed input size or normalized values, resize
    # (cv2.resize) and scale here, before the batch axis is added.
    # TODO: confirm against the model's input requirements.
    return np.expand_dims(image, axis=0)  # Add batch dimension


def colorize_image(image):
    """Run the colorization model on a batched grayscale image.

    Feeds ``image`` to the module-level ``model``, keeps the first entry of
    the predicted batch, reduces it with ``argmax`` over its last axis and
    passes the result to OpenCV for a Lab-to-BGR conversion.
    """
    # NOTE(review): argmax removes the last axis and yields integer indices,
    # whereas COLOR_LAB2BGR is documented for 3-channel 8-bit/float images.
    # Confirm this post-processing against the model's actual output format.
    first_prediction = model.predict(image)[0]
    channel_indices = np.argmax(first_prediction, axis=-1)
    return cv2.cvtColor(channel_indices, cv2.COLOR_LAB2BGR)


@app.post("/upload")
async def upload_image(image: UploadFile = File(...)):
    """Colorize an uploaded image and return it as a base64-encoded JPEG.

    Args:
        image: The uploaded file from the multipart form field ``image``.

    Returns:
        ``{"colorized_image": <base64 string>}`` on success, where the string
        is the base64 (ASCII) encoding of the JPEG bytes, or
        ``{"error": <message>}`` if preprocessing, colorization or JPEG
        encoding fails.
    """
    # Read image content
    image_bytes = await image.read()

    try:
        grayscale_image = preprocess_image(image_bytes)
        colorized_image = colorize_image(grayscale_image)

        # imencode returns a success flag alongside the buffer; a False flag
        # means the buffer must not be used.
        encoded_ok, colorized_image_buffer = cv2.imencode(".jpg", colorized_image)
        if not encoded_ok:
            raise ValueError("Failed to encode the colorized image as JPEG.")
    except Exception as e:
        # Request-level boundary: report any processing failure (e.g. an
        # invalid image format) to the client instead of crashing.
        return {"error": str(e)}

    # Raw JPEG bytes are not valid UTF-8, so they cannot be placed in a JSON
    # response as-is; base64 gives a JSON-safe ASCII representation.
    jpeg_bytes = colorized_image_buffer.tobytes()
    return {"colorized_image": base64.b64encode(jpeg_bytes).decode("ascii")}


if __name__ == "__main__":
    import uvicorn

    # Pass the app object rather than the "main:app" import string: the
    # string only resolves when this file is named main.py, and it makes
    # uvicorn import the module a second time (loading the model twice).
    uvicorn.run(app, host="0.0.0.0", port=8000)
