In [3]:
import gradio as gr
from PIL import Image

# Model-inference dependencies: PyTorch, torchvision transforms, segmentation_models_pytorch
import torch
from torchvision import transforms
from PIL import Image
import segmentation_models_pytorch as smp

# ===================================================================
# 1. CONSTANTS AND DATABASES
# ===================================================================

# Hard-coded stand-ins for values that would normally be read from the
# dataset's meta.json files.
# Output-channel counts for the two MAnet models (passed as `classes=` in
# load_models). Class ID 0 is treated as background by calculate_final_quote.
PART_CLASSES = 16  # part classes + 1 for background
DAMAGE_CLASSES = 6 # damage classes + 1 for background

# Class-ID -> name lookups for the argmax masks. IDs missing here are reported
# as "unknown_part" / "unknown_damage" by calculate_final_quote.
# NOTE(review): only 3 of the 15 non-background part IDs and 3 of the 5
# non-background damage IDs are named, and "unknown_part" has no entry in
# COST_DATABASE — TODO fill these in from meta.json.
PART_ID_TO_NAME = {1: "bumper", 2: "door", 5: "headlight"}
DAMAGE_ID_TO_NAME = {1: "scratch", 2: "dent", 3: "cracked"}

# Repair prices: car model -> part name -> damage name (or "replacement") -> cost.
# The car-model keys also populate the UI dropdown. Costs are displayed with a
# "$" prefix; presumably whole US dollars — confirm.
COST_DATABASE = {
    "Toyota Camry": {
        "bumper": {"scratch": 150, "dent": 300, "cracked": 800, "replacement": 800},
        "door": {"scratch": 200, "dent": 500, "cracked": 1200, "replacement": 1200},
        "headlight": {"scratch": 50, "dent": 250, "cracked": 400, "replacement": 400}
    },
    "BMW X5": {
        "bumper": {"scratch": 400, "dent": 900, "cracked": 2000, "replacement": 2000},
        "door": {"scratch": 600, "dent": 1500, "cracked": 3000, "replacement": 3000},
        "headlight": {"scratch": 150, "dent": 700, "cracked": 1200, "replacement": 1200}
    }
}

# ===================================================================
# 2. HELPER FUNCTIONS
# ===================================================================

def prepare_image(image_pil):
    """Convert a user-uploaded PIL image into a normalized single-image batch.

    The image is forced to 3-channel RGB first. The ImageNet normalization
    below has exactly three mean/std entries, so RGBA (e.g. PNG with alpha),
    grayscale ("L") or palette ("P") images would otherwise fail inside
    ``transforms.Normalize`` with a channel-count mismatch.

    Args:
        image_pil: A ``PIL.Image.Image`` in any mode.

    Returns:
        A float tensor of shape (1, 3, 320, 320).
    """
    transform = transforms.Compose([
        transforms.Resize((320, 320)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    return transform(image_pil.convert("RGB")).unsqueeze(0)

def get_prediction_mask(model, image_tensor):
    """Run ``model`` on ``image_tensor`` and return per-pixel class IDs.

    Inference runs under ``torch.no_grad()`` so no autograd graph is built.
    The class axis (dim 1 of the model output) is collapsed with ``argmax``,
    then all size-1 dimensions are squeezed away. For a single-image batch
    this yields an (H, W) mask.

    Args:
        model: Callable returning logits with classes on dim 1.
        image_tensor: Input batch passed straight to ``model``.

    Returns:
        An integer tensor of predicted class indices.
    """
    with torch.no_grad():
        logits = model(image_tensor)
    class_ids = logits.argmax(dim=1)
    return class_ids.squeeze()

# ===================================================================
# 3. MAIN LOGIC FUNCTIONS
# ===================================================================

def load_models(part_model_path, damage_model_path):
    """Build the part and damage MAnet models and load their trained weights.

    Args:
        part_model_path: Path to the part-segmentation state dict.
        damage_model_path: Path to the damage-segmentation state dict.

    Returns:
        Tuple ``(part_model, damage_model)``, both on CPU and in eval mode.
    """
    part_model = _load_manet(part_model_path, PART_CLASSES)
    damage_model = _load_manet(damage_model_path, DAMAGE_CLASSES)
    return part_model, damage_model


def _load_manet(weights_path, num_classes):
    """Create one ResNet-50 MAnet with ``num_classes`` outputs and load weights into it."""
    # encoder_weights=None: by default smp downloads ImageNet encoder weights
    # at construction time, and load_state_dict below overwrites them anyway.
    # The download only slows startup and breaks offline runs.
    model = smp.MAnet(encoder_name="resnet50", encoder_weights=None, classes=num_classes)
    # The file is expected to be a plain state dict (it goes to load_state_dict).
    # weights_only=True restricts unpickling to tensors/containers, so a tampered
    # .pth cannot execute arbitrary code on load.
    state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model

def calculate_final_quote(car_model, parts_mask, damages_mask):
    """Turn predicted part/damage masks into a repair quote for ``car_model``.

    For every non-background part ID present in ``parts_mask``, compute the
    share of its pixels that are also non-background in ``damages_mask``.
    Parts more than 50% damaged are priced at their "replacement" cost.
    Otherwise, each distinct damage type found on the part is priced once.

    Parts (or car models) with no entry in ``COST_DATABASE`` still appear in
    the damage analysis. If damaged, they are flagged in the breakdown as
    unpriced instead of aborting the whole request with a ``KeyError``.

    Args:
        car_model: Key into ``COST_DATABASE`` (e.g. "Toyota Camry").
        parts_mask: Integer tensor of part class IDs, 0 = background.
        damages_mask: Integer tensor of damage class IDs, indexable by the
            boolean part masks (i.e. same shape as ``parts_mask``), 0 = background.

    Returns:
        ``(total_cost, cost_breakdown, damage_analysis)`` as strings: the
        total formatted as ``"$<amount>"`` plus two newline-separated lists.
    """
    total_cost = 0
    cost_breakdown = []
    damage_analysis = []

    # An unknown car model yields an empty price table: every part is then
    # reported as unpriced rather than raising KeyError.
    price_table = COST_DATABASE.get(car_model, {})
    is_damaged_mask = damages_mask > 0

    for part_id in torch.unique(parts_mask).tolist():
        if part_id == 0:  # background
            continue
        part_name = PART_ID_TO_NAME.get(part_id, "unknown_part")
        is_current_part_mask = parts_mask == part_id

        part_area = torch.sum(is_current_part_mask).item()
        if part_area == 0:
            continue

        damaged_area = torch.sum(is_current_part_mask & is_damaged_mask).item()
        damage_percentage = (damaged_area / part_area) * 100
        damage_analysis.append(f"- {part_name.title()}: {damage_percentage:.1f}% damaged")

        part_prices = price_table.get(part_name)
        if part_prices is None:
            # e.g. a part ID missing from PART_ID_TO_NAME ("unknown_part").
            if damaged_area > 0:
                cost_breakdown.append(
                    f"- {part_name.title()} is damaged but has no price entry for {car_model}"
                )
            continue

        if damage_percentage > 50.0 and "replacement" in part_prices:
            cost = part_prices["replacement"]
            cost_breakdown.append(f"- {part_name.title()} needs REPLACEMENT: ${cost}")
            total_cost += cost
            continue

        # Price each distinct damage type found on this part once.
        for damage_id in torch.unique(damages_mask[is_current_part_mask]).tolist():
            if damage_id == 0:
                continue
            damage_name = DAMAGE_ID_TO_NAME.get(damage_id, "unknown_damage")
            if damage_name in part_prices:
                cost = part_prices[damage_name]
                cost_breakdown.append(f"- {part_name.title()} has {damage_name}: ${cost}")
                total_cost += cost

    return f"${total_cost}", "\n".join(cost_breakdown), "\n".join(damage_analysis)

# ===================================================================
# 4. LOAD MODELS ONCE AT STARTUP
# ===================================================================

# Load both segmentation models once, when this cell runs; the Gradio
# callback below reads these module-level instances on every request.
PART_MODEL_PATH, DAMAGE_MODEL_PATH = "MANet Model.pth", "MANet Model1.pth"
print("Loading models...")
part_model, damage_model = load_models(
    PART_MODEL_PATH,
    DAMAGE_MODEL_PATH,
)
print("Models loaded successfully.")

# ===================================================================
# 5. DEFINE THE MAIN PREDICTION FUNCTION FOR GRADIO
# ===================================================================

def get_cost_estimate(image_pil, car_model_str):
    """Gradio callback: turn an uploaded image and a car model into a quote.

    Args:
        image_pil: The uploaded image as a PIL image, or ``None`` when nothing
            has been uploaded.
        car_model_str: The car model chosen in the dropdown, or ``None``.

    Returns:
        A ``(total_cost, cost_breakdown, damage_analysis)`` tuple of strings,
        one per output textbox.
    """
    missing_input = image_pil is None or car_model_str is None
    if missing_input:
        return "$0", "Please upload an image and select a model.", ""

    batch = prepare_image(image_pil)
    # Both models see the exact same preprocessed batch.
    part_ids = get_prediction_mask(part_model, batch)
    damage_ids = get_prediction_mask(damage_model, batch)

    total, breakdown, analysis = calculate_final_quote(car_model_str, part_ids, damage_ids)
    return total, breakdown, analysis

# ===================================================================
# 6. CREATE AND LAUNCH THE GRADIO INTERFACE
# ===================================================================

# Two inputs (image + car model) map onto get_cost_estimate's two parameters;
# its three returned strings fill the three output textboxes in order.
demo = gr.Interface(
    fn=get_cost_estimate,
    inputs=[
        gr.Image(type="pil", label="Upload Car Image"),
        gr.Dropdown(list(COST_DATABASE.keys()), label="Select Car Model")
    ],
    outputs=[
        gr.Textbox(label="Total Estimated Cost"),
        gr.Textbox(label="Cost Breakdown"),
        gr.Textbox(label="Damage Analysis")
    ],
    title="Car Damage Repair Cost Estimator",
    description="Upload an image of a damaged car and select the model to get a repair cost estimate.",
    # `allow_flagging` is the deprecated pre-Gradio-5 name for this option.
    flagging_mode="never",
    # NOTE(review): these example files must exist relative to the working directory.
    examples=[
        ["examples/sample_car1.jpg", "Toyota Camry"],
        ["examples/sample_car2.jpg", "BMW X5"]
    ]
)

# Inside a notebook kernel __name__ is also "__main__", so this launches the app
# when the cell runs as well as when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()



* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.




In [1]:
pip install gradio


Collecting gradioNote: you may need to restart the kernel to use updated packages.

  Downloading gradio-5.47.2-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<25.0,>=22.0 (from gradio)
  Downloading aiofiles-24.1.0-py3-none-any.whl.metadata (10 kB)
Collecting brotli>=1.1.0 (from gradio)
  Downloading Brotli-1.1.0-cp312-cp312-win_amd64.whl.metadata (5.6 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.117.1-py3-none-any.whl.metadata (28 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.6.1-py3-none-any.whl.metadata (2.9 kB)
Collecting gradio-client==1.13.3 (from gradio)
  Downloading gradio_client-1.13.3-py3-none-any.whl.metadata (7.1 kB)
Collecting groovy~=0.1 (from gradio)
  Downloading groovy-0.1.2-py3-none-any.whl.metadata (6.1 kB)
Collecting huggingface-hub<2.0,>=0.33.5 (from gradio)
  Downloading huggingface_hub-0.35.1-py3-none-any.whl.metadata (14 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1