In [29]:
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.models.detection import maskrcnn_resnet50_fpn
import numpy as np
import gradio as gr

In [30]:
def refine_masks(mask, threshold=0.5):
    """Refine the masks based on threshold and remove small regions."""
    mask = mask > threshold
    mask = mask.float()
    return mask

In [31]:
def calculate_metrics(pred_mask, true_mask):
    """Calculate accuracy, precision, recall, F1-score, TP, FP, TN, FN."""
    pred_mask_flat = pred_mask.view(-1)
    true_mask_flat = true_mask.view(-1)

    tp = (pred_mask_flat * true_mask_flat).sum().item()  # True Positives
    fp = (pred_mask_flat * (1 - true_mask_flat)).sum().item()  # False Positives
    tn = ((1 - pred_mask_flat) * (1 - true_mask_flat)).sum().item()  # True Negatives
    fn = ((1 - pred_mask_flat) * true_mask_flat).sum().item()  # False Negatives

    # Calculate metrics
    accuracy = (tp + tn) / (tp + fp + tn + fn) if (tp + fp + tn + fn) > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    return accuracy, precision, recall, f1_score, tp, fp, tn, fn

In [32]:
def remove_background_advanced(image):
    """Remove background from the uploaded image using Mask R-CNN."""
    original_image = Image.fromarray(image).convert("RGB")
    
    # Define the transformation
    transform = T.Compose([
        T.ToTensor(),
        T.Resize((1024, 1024))  # Increasing the resolution for better accuracy
    ])
    image_tensor = transform(original_image)

    # Load the pre-trained Mask R-CNN model
    model = maskrcnn_resnet50_fpn(weights='DEFAULT')  # Update to use weights parameter
    model.eval()

    # Move image to the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image_tensor = image_tensor.to(device)
    model.to(device)

    # Perform inference
    with torch.no_grad():
        prediction = model([image_tensor])

    # Extract the most confident mask
    scores = prediction[0]['scores']
    masks = prediction[0]['masks']
    highest_score_index = scores.argmax()
    mask = masks[highest_score_index, 0]

    # Refine the mask
    mask = refine_masks(mask, threshold=0.8)  # Adjust threshold for better results

    # Apply the refined mask to the image
    masked_image = image_tensor * mask

    # Convert the tensor back to an image
    final_image = T.ToPILImage()(masked_image.cpu()).convert("RGB")
    original_image = original_image.resize((1024, 1024))  # Resize for consistency
    mask_image = T.ToPILImage()(mask.cpu()).convert("L")  # Convert mask to grayscale

    # Simulate a true mask for demonstration purposes (this should be replaced with actual ground truth)
    true_mask = torch.zeros_like(mask)  # Replace with actual true mask if available
    true_mask[mask > 0] = 1  # Assuming the object is present in the mask

    # Calculate metrics
    accuracy, precision, recall, f1_score, tp, fp, tn, fn = calculate_metrics(mask.cpu(), true_mask.cpu())
    
    # Print metrics to the console
    print(f"Accuracy: {accuracy:.2f}, Precision: {precision:.2f}, Recall: {recall:.2f}, F1-Score: {f1_score:.2f}")
    print(f"TP: {tp}, FP: {fp}, TN: {tn}, FN: {fn}")

    return final_image, original_image, mask_image

In [33]:
def gradio_interface(image):
    """Gradio interface function to process the uploaded image."""
    final_image, original_image, mask_image = remove_background_advanced(image)
    return final_image, original_image, mask_image


In [34]:
# Create Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="numpy"),  # Single input for the image
    outputs=[gr.Image(type="pil"), gr.Image(type="pil"), gr.Image(type="pil")],  # Remove metric outputs
    title="Background Removal with Mask R-CNN",
    description="Upload an image to remove its background using Mask R-CNN. The process includes the original image and the mask."
)

In [35]:
# Launch the interface
iface.launch()

* Running on local URL:  http://127.0.0.1:7864

To create a public link, set `share=True` in `launch()`.




Accuracy: 1.00, Precision: 1.00, Recall: 1.00, F1-Score: 1.00
TP: 163889.0, FP: 0.0, TN: 884687.0, FN: 0.0
Accuracy: 1.00, Precision: 1.00, Recall: 1.00, F1-Score: 1.00
TP: 380798.0, FP: 0.0, TN: 667778.0, FN: 0.0
Accuracy: 1.00, Precision: 1.00, Recall: 1.00, F1-Score: 1.00
TP: 343161.0, FP: 0.0, TN: 705415.0, FN: 0.0
Accuracy: 1.00, Precision: 1.00, Recall: 1.00, F1-Score: 1.00
TP: 612375.0, FP: 0.0, TN: 436201.0, FN: 0.0
Accuracy: 1.00, Precision: 1.00, Recall: 1.00, F1-Score: 1.00
TP: 173848.0, FP: 0.0, TN: 874728.0, FN: 0.0
Accuracy: 1.00, Precision: 1.00, Recall: 1.00, F1-Score: 1.00
TP: 175362.0, FP: 0.0, TN: 873214.0, FN: 0.0
