In [2]:
# Install dependencies
!pip install -U torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113
!pip install 'git+https://github.com/facebookresearch/detectron2.git'
!pip install gradio


Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu113
Collecting git+https://github.com/facebookresearch/detectron2.git
  Cloning https://github.com/facebookresearch/detectron2.git to /tmp/pip-req-build-e3n1e6n0
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/detectron2.git /tmp/pip-req-build-e3n1e6n0
  Resolved https://github.com/facebookresearch/detectron2.git to commit b1c43ffbc995426a9a6b5c667730091a384e0fa4
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting yacs>=0.1.8 (from detectron2==0.6)
  Downloading yacs-0.1.8-py3-none-any.whl.metadata (639 bytes)
Collecting fvcore<0.1.6,>=0.1.5 (from detectron2==0.6)
  Downloading fvcore-0.1.5.post20221221.tar.gz (50 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.2/50.2 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting iopath<0.1.10,>=0.1.7 (from detectron2==0.6)
  Do

Collecting gradio
  Downloading gradio-5.9.1-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.6-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.5.0-py3-none-any.whl.metadata (3.0 kB)
Collecting gradio-client==1.5.2 (from gradio)
  Downloading gradio_client-1.5.2-py3-none-any.whl.metadata (7.1 kB)
Collecting markupsafe~=2.0 (from gradio)
  Downloading MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart>=0.0.18 (from gradio)
  Downloading python_multipart-0.0.20-py3-none-any.whl.metadata (1.8 kB)
Collecting ruff>=0.2.2 (from gradio)
  Downloading ruff-0.8.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metad

In [1]:
import torch
import torchvision.transforms as T
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.models.detection import maskrcnn_resnet50_fpn
import numpy as np
import gradio as gr

In [2]:
def refine_masks(mask, threshold=0.5):
    """Binarise a soft mask: 1.0 where ``mask > threshold``, 0.0 elsewhere.

    Only a strict per-pixel threshold is applied; small connected regions are
    NOT removed.

    Args:
        mask: Tensor of per-pixel scores (any shape).
        threshold: Values strictly greater than this become 1.0.

    Returns:
        A float tensor with the same shape as ``mask`` holding 0.0 / 1.0.
    """
    return (mask > threshold).float()

_MASK_RCNN_CACHE = {}


def _load_mask_rcnn(device):
    """Return a pre-trained, eval-mode Mask R-CNN on ``device``, built once per device.

    Building the model and loading its weights is expensive, so the instance is
    memoised in ``_MASK_RCNN_CACHE`` instead of being rebuilt on every request.
    """
    key = str(device)
    model = _MASK_RCNN_CACHE.get(key)
    if model is None:
        # `weights="DEFAULT"` replaces the deprecated `pretrained=True` flag
        # (torchvision >= 0.13).
        model = maskrcnn_resnet50_fpn(weights="DEFAULT")
        model.eval()
        model.to(device)
        _MASK_RCNN_CACHE[key] = model
    return model


def remove_background_advanced(image, mask_threshold=0.8):
    """Black out everything except the most confident Mask R-CNN detection.

    Args:
        image: Image as a numpy array (as delivered by ``gr.Image(type="numpy")``);
            it is converted to RGB via PIL.
        mask_threshold: Per-pixel mask probability above which a pixel of the
            selected instance is kept. Defaults to 0.8.

    Returns:
        An RGB ``PIL.Image`` of the same size as the input, with every pixel
        outside the selected mask set to black.

    Raises:
        ValueError: If the model returns no detections for the image.
    """
    pil_image = Image.fromarray(image).convert("RGB")

    # No manual resize: Mask R-CNN's own transform rescales the input and maps
    # the predicted masks back to the input size, so a fixed 1024x1024 resize
    # only distorted the aspect ratio and changed the output size.
    image_tensor = T.ToTensor()(pil_image)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_mask_rcnn(device)
    image_tensor = image_tensor.to(device)

    with torch.no_grad():
        prediction = model([image_tensor])[0]

    scores = prediction["scores"]
    if scores.numel() == 0:
        # argmax() on an empty tensor would raise an opaque RuntimeError.
        raise ValueError("No object was detected in the image.")

    best_index = scores.argmax()
    mask = refine_masks(prediction["masks"][best_index, 0], threshold=mask_threshold)

    # The (3, H, W) image times the (H, W) binary mask broadcasts over channels.
    masked_image = image_tensor * mask
    return T.ToPILImage()(masked_image.cpu()).convert("RGB")

In [3]:
# Gradio interface
def gradio_interface(image):
    """Gradio callback: run background removal on the uploaded image.

    Args:
        image: numpy array from ``gr.Image(type="numpy")``, or None when the
            user submits without uploading an image.

    Returns:
        The background-removed PIL image from ``remove_background_advanced``.

    Raises:
        gr.Error: With a user-facing message when no image was provided or the
            pipeline rejects the input with a ValueError.
    """
    if image is None:
        # Fail with a readable message instead of a PIL traceback.
        raise gr.Error("Please upload an image first.")
    try:
        return remove_background_advanced(image)
    except ValueError as exc:
        # gr.Error messages are displayed in the UI; other exceptions are not.
        raise gr.Error(str(exc)) from exc

In [4]:
# Create Gradio interface: numpy image in -> background-removed PIL image out,
# with gradio_interface as the callback.
iface = gr.Interface(
    gradio_interface,
    gr.Image(type="numpy"),
    gr.Image(type="pil"),
    title="Background Removal with Mask R-CNN",
    description="Upload an image to remove its background using Mask R-CNN.",
)

In [5]:
# Launch the interface.
# In Colab, Gradio enables share=True automatically and prints a temporary
# public *.gradio.live URL (see the cell output); pass share=False to opt out,
# or debug=True to have errors shown in the notebook.
iface.launch()

Running Gradio in a Colab notebook requires sharing enabled. Automatically setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://f38dcf9ea80ec6725a.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




In [6]:
import torch

def calculate_metrics(pred_mask, true_mask):
    """Calculate pixel-wise accuracy, precision, recall, and F1-score of a binary mask.

    Any non-zero element counts as foreground (positive), matching the
    semantics of ``torch.logical_*``.

    Args:
        pred_mask: Predicted mask; a tensor or anything ``torch.as_tensor`` accepts.
        true_mask: Ground-truth mask. It must have the same shape as ``pred_mask``
            once singleton dimensions are squeezed, so ``(1, H, W)`` vs ``(H, W)``
            is fine. It is moved to ``pred_mask``'s device if needed.

    Returns:
        Tuple ``(accuracy, precision, recall, f1_score)`` of floats. A metric
        whose denominator is zero (including accuracy on empty masks) is 0.0.

    Raises:
        ValueError: If the mask shapes differ after squeezing singleton dims.
    """
    pred = torch.as_tensor(pred_mask).squeeze().bool()
    true = torch.as_tensor(true_mask).squeeze().bool().to(pred.device)
    if pred.shape != true.shape:
        # Without this check the logical ops would silently broadcast
        # mismatched masks (e.g. (H, 1) vs (1, H) -> (H, H)).
        raise ValueError(
            f"Mask shapes differ: {tuple(pred.shape)} vs {tuple(true.shape)}"
        )

    # Confusion-matrix counts.
    tp = (pred & true).sum().item()
    fp = (pred & ~true).sum().item()
    tn = (~pred & ~true).sum().item()
    fn = (~pred & true).sum().item()
    total = tp + fp + tn + fn

    accuracy = (tp + tn) / total if total > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1_score = (
        2 * (precision * recall) / (precision + recall)
        if (precision + recall) > 0
        else 0.0
    )
    return accuracy, precision, recall, f1_score

# Toy 2x2 placeholder masks: the prediction marks one extra pixel as foreground.
predicted_mask = torch.tensor([[1, 0],
                               [0, 1]])
true_mask = torch.tensor([[1, 0],
                          [0, 0]])

metrics = calculate_metrics(predicted_mask, true_mask)
accuracy, precision, recall, f1_score = metrics
for label, value in zip(("Accuracy", "Precision", "Recall", "F1 Score"), metrics):
    print(f"{label}: {value:.4f}")


Accuracy: 0.7500
Precision: 0.5000
Recall: 1.0000
F1 Score: 0.6667
