In [None]:
from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
import os
import firebase_admin
from firebase_admin import credentials, firestore
from google.cloud.firestore import SERVER_TIMESTAMP
from utils import predict_age_children, predict_age_adults
from pydantic import BaseModel
from typing import List
import json
import uvicorn
from datetime import datetime
import warnings
# Silence every warning for the whole process (this also hides library
# deprecation notices).
warnings.filterwarnings('ignore')

app = FastAPI(
    title="Age Prediction API",
    description="API for predicting age from images with Firebase integration",
)

# CORS configuration: any origin, method and header is accepted, and
# credentialed requests are allowed.
# NOTE(review): Starlette's CORSMiddleware docs say the "*" wildcards should
# not be combined with allow_credentials=True; confirm the intended frontend
# origins and list them explicitly.
_cors_options = dict(
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)

# Configuration
# Upload policy: the MIME types accepted by /predict-age/ and the maximum
# upload size in bytes.
ALLOWED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/jpg"]
MAX_FILE_SIZE = 10 << 20  # 10MB (10 * 1024 * 1024 bytes)

# Directory (relative to the working directory) where uploads are written
# temporarily; it is created at import time if missing.
UPLOAD_DIR = "uploads"
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Initialize Firebase once per process. Re-executing this code (e.g. re-running
# the notebook cell) must not re-initialize: initialize_app raises if the
# default app already exists.
# The service-account key path can be overridden with the FIREBASE_CREDENTIALS
# environment variable; it defaults to the key file used previously.
FIREBASE_CREDENTIALS_PATH = os.environ.get(
    "FIREBASE_CREDENTIALS",
    "age-pred-firebase-adminsdk-fbsvc-8b4f39e55b.json",
)

try:
    # Public API: raises ValueError when the default app does not exist yet.
    firebase_admin.get_app()
except ValueError:
    try:
        cred = credentials.Certificate(FIREBASE_CREDENTIALS_PATH)
        firebase_admin.initialize_app(cred)
    except Exception as e:
        # Fail fast at import time; keep the original error as the cause.
        raise RuntimeError(f"Firebase initialization failed: {str(e)}") from e

# Firestore client bound to the default Firebase app.
db = firestore.client()

# Firestore collection names
IMAGES_COLLECTION = "image_predictions"  # one document per /predict-age/ call
FEEDBACK_COLLECTION = "user_feedback"  # one document per /update-age/ call

class UpdateAgeRequest(BaseModel):
    """Request body for ``/update-age/``: user feedback on one stored prediction."""

    # Firestore document ID of the prediction record (the ``image_id`` that
    # ``/predict-age/`` returns).
    image_id: str
    # True if the user confirmed the stored prediction as accurate.
    is_correct: bool
    # User-supplied corrected age values, kept as strings.
    corrected_ages: List[str]

@app.post("/predict-age/", 
         summary="Predict age from image",
         response_description="Returns age predictions and image ID")
async def predict_age(
    file: UploadFile = File(..., description="Image file to process"),
    category: str = Form(..., description="Either 'child' or 'adult'"),
    request: Request = None
):
    """
    Process an image to predict age and store results in Firebase.

    - **file**: Image file (JPEG/PNG), at most 10MB
    - **category**: 'child' or 'adult' (case-insensitive) for different prediction models

    The upload is written to a temporary file, passed to the matching model,
    and recorded in Firestore. The temporary file is removed afterwards,
    whether or not the request succeeds.

    Responds with 400 for a disallowed content type, an unknown category, or
    an empty or oversized file. Responds with 500 if prediction, Firestore
    storage, or anything else fails.
    """
    import uuid  # only needed here, for collision-free temporary file names

    file_path = None
    try:
        # Validate file
        if file.content_type not in ALLOWED_IMAGE_TYPES:
            raise HTTPException(400, "Only JPEG or PNG images are allowed")

        # Only two models exist. Reject anything else instead of silently
        # falling back to the adult model and storing a bogus category.
        normalized_category = category.strip().lower()
        if normalized_category not in ("child", "adult"):
            raise HTTPException(400, "category must be either 'child' or 'adult'")

        # Read at most one byte past the limit, so an oversized upload is
        # rejected without buffering the whole body in memory.
        image_bytes = await file.read(MAX_FILE_SIZE + 1)
        if len(image_bytes) > MAX_FILE_SIZE:
            raise HTTPException(400, "File too large (max 10MB)")
        if not image_bytes:
            raise HTTPException(400, "Empty file received")

        # Save image temporarily. The client-supplied filename is untrusted:
        # - keep only its last path component, so "../x" cannot escape UPLOAD_DIR;
        # - add a random token, so concurrent uploads with the same name in
        #   the same second cannot overwrite each other.
        safe_name = os.path.basename(file.filename or "") or "upload"
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        file_path = os.path.join(
            UPLOAD_DIR, f"{timestamp}_{uuid.uuid4().hex}_{safe_name}"
        )
        with open(file_path, "wb") as buffer:
            buffer.write(image_bytes)

        # Get predictions.
        # NOTE(review): the model call and the Firestore write below are
        # synchronous calls made directly from this async handler, so they
        # block the event loop while they run. Consider run_in_threadpool
        # once the utils models are confirmed thread-safe.
        predictor = (
            predict_age_children if normalized_category == "child" else predict_age_adults
        )
        try:
            predictions, image_base64 = predictor(file_path)
        except Exception as pred_error:
            raise HTTPException(500, f"Prediction failed: {str(pred_error)}") from pred_error

        # Get user IP. FastAPI normally injects Request, but the signature
        # defaults it to None, so guard against direct calls.
        client = request.client if request is not None else None
        user_ip = client.host if client else "UNKNOWN"

        # Store data in Firestore
        try:
            doc_ref = db.collection(IMAGES_COLLECTION).document()
            doc_data = {
                "user_ip": user_ip,
                "original_filename": file.filename,
                "image_size": len(image_bytes),
                "category": normalized_category,
                "age_prediction": predictions,
                "is_correct": None,
                "corrected_age": None,
                "created_at": SERVER_TIMESTAMP,
                "image_id": doc_ref.id
            }
            doc_ref.set(doc_data)
        except Exception as firestore_error:
            raise HTTPException(500, f"Firestore error: {str(firestore_error)}") from firestore_error

        return {
            "status": "success",
            "predictions": predictions,
            "image_base64": image_base64,
            "image_id": doc_ref.id
        }

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(500, f"Unexpected error: {str(e)}") from e
    finally:
        # Always remove the temporary copy, including on failure paths.
        # Best-effort: a leftover file must not mask the real result or error.
        if file_path is not None:
            try:
                os.remove(file_path)
            except OSError:
                pass

@app.post("/update-age/", 
         summary="Update age prediction feedback",
         response_description="Confirmation message")
async def update_age(request: UpdateAgeRequest):
    """
    Update the correctness status of a prediction.

    - **image_id**: The Firestore document ID returned by /predict-age/
    - **is_correct**: Whether the prediction was accurate
    - **corrected_ages**: List of corrected age values

    Adds a feedback record and updates the original prediction document in
    one atomic batch: either both writes are applied or neither is.

    Responds with 400 for a missing or malformed image ID, 404 if no
    prediction record has that ID, and 500 if Firestore fails.
    """
    try:
        # Validate input. corrected_ages is already guaranteed to be a list
        # by the pydantic model, so only image_id needs checking here.
        if not request.image_id:
            raise HTTPException(400, "Image ID is required")
        # A "/" would make Firestore treat the ID as a multi-segment path
        # (e.g. a document in a sub-collection) instead of a direct child of
        # IMAGES_COLLECTION.
        if "/" in request.image_id:
            raise HTTPException(400, "Invalid image ID")

        doc_ref = db.collection(IMAGES_COLLECTION).document(request.image_id)
        doc = doc_ref.get()

        if not doc.exists:
            raise HTTPException(404, "Image record not found")

        original_data = doc.to_dict() or {}

        # Fields updated on the original prediction document.
        update_data = {
            "is_correct": request.is_correct,
            "corrected_age": request.corrected_ages,
            "updated_at": SERVER_TIMESTAMP
        }

        # Feedback is also stored separately, with a copy of the original
        # prediction for later analysis.
        feedback_ref = db.collection(FEEDBACK_COLLECTION).document()
        feedback_data = {
            "image_id": request.image_id,
            "original_prediction": original_data.get("age_prediction"),
            "is_correct": request.is_correct,
            "corrected_ages": request.corrected_ages,
            "feedback_time": SERVER_TIMESTAMP,
            "original_category": original_data.get("category")
        }

        try:
            # A single batch commits both writes atomically. Previously, a
            # failed update could leave an orphaned feedback record behind.
            # update() also fails the whole batch if the document was deleted
            # after the existence check above.
            batch = db.batch()
            batch.set(feedback_ref, feedback_data)
            batch.update(doc_ref, update_data)
            batch.commit()
        except Exception as firestore_error:
            raise HTTPException(500, f"Firestore update failed: {str(firestore_error)}") from firestore_error

        return {"status": "success", "message": "Prediction updated successfully"}

    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(500, f"Unexpected error: {str(e)}") from e

if __name__ == "__main__":
    # Development entry point: serve on all interfaces, port 8000, with reload.
    # NOTE(review): the "main:app" import string assumes this code is saved as
    # main.py. Inside a Jupyter kernel __name__ is also "__main__", so running
    # this cell there would try to start the server. Confirm the intended usage.
    uvicorn.run(
        "main:app",
        port=8000,
        host="0.0.0.0",
        reload=True,
    )