In [None]:
!pip install torch torchvision matplotlib


In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np

# Load Mask R-CNN (ResNet-50 FPN backbone) with COCO-pretrained weights.
# torchvision >= 0.13 selects weights through the `weights=` enum and warns that
# `pretrained=True` is deprecated; the legacy flag is kept only as a fallback
# for older releases that do not provide the enum.
_detection_models = torchvision.models.detection
if hasattr(_detection_models, 'MaskRCNN_ResNet50_FPN_Weights'):
    model = _detection_models.maskrcnn_resnet50_fpn(
        weights=_detection_models.MaskRCNN_ResNet50_FPN_Weights.DEFAULT
    )
else:
    model = _detection_models.maskrcnn_resnet50_fpn(pretrained=True)
# Evaluation mode: torchvision detection models then return predictions
# (instead of training losses) and freeze training-only layer behavior.
model.eval()

# Convert a PIL image to a float tensor of shape (C, H, W) with values in [0, 1].
# No resizing or normalization happens here: torchvision's detection models
# resize and normalize their inputs internally.
transform = transforms.Compose([
    transforms.ToTensor(),
])

# Function to get segmented parts of an image
def get_segmented_parts(image_url, timeout=30):
    """Download an image and run the Mask R-CNN model on it.

    Args:
        image_url: HTTP(S) URL of the image to segment.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        A ``(detections, img)`` tuple. ``detections`` is the model's output
        dict for this single image (it carries ``'boxes'``, ``'labels'``,
        ``'scores'`` and ``'masks'``), and ``img`` is the downloaded image
        converted to RGB as a PIL image.

    Raises:
        requests.HTTPError: If the server responds with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    # Without a timeout, requests can block indefinitely on an unresponsive host.
    response = requests.get(image_url, timeout=timeout)
    # Fail with a clear HTTP error instead of letting PIL choke on an error page.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content)).convert('RGB')

    # Add a batch dimension and put the input on the same device as the model
    # so the function also works after `model` has been moved to a GPU.
    device = next(model.parameters()).device
    batch_t = torch.unsqueeze(transform(img), 0).to(device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        detections = model(batch_t)

    # One image in, so return the single per-image result.
    return detections[0], img

# Function to visualize the segmented parts
def visualize_segmented_parts(detections, img, threshold=0.5, mask_threshold=0.5, categories=None):
    """Draw detected instance masks and labels on top of the image.

    Args:
        detections: Per-image detection dict, as returned by
            ``get_segmented_parts``. It must contain ``'masks'`` (indexed as
            ``(N, 1, H, W)``), ``'labels'``, ``'scores'`` and ``'boxes'``.
        img: The image the detections were computed on (anything
            ``Axes.imshow`` accepts, e.g. a PIL image).
        threshold: Minimum detection score for an instance to be drawn.
        mask_threshold: Per-pixel mask value above which a pixel is treated
            as part of the instance.
        categories: Optional sequence mapping label ids to class names. When
            omitted, the numeric label id is shown.
    """
    masks = detections['masks'].detach().cpu()
    labels = detections['labels']
    scores = detections['scores']
    boxes = detections['boxes']

    _, ax = plt.subplots(1, figsize=(12, 9))
    ax.imshow(img)

    # A qualitative colormap gives each drawn instance a distinguishable color.
    palette = plt.get_cmap('tab10')
    drawn = 0

    for mask, label, score, box in zip(masks, labels, scores, boxes):
        score = float(score)
        if score < threshold:
            continue
        label = int(label)

        # Binary instance mask of shape (H, W).
        binary = (mask[0] > mask_threshold).numpy()

        # RGBA overlay that is fully transparent outside the mask, so pixels
        # that are not part of the instance keep their original brightness.
        red, green, blue, _ = palette(drawn % palette.N)
        overlay = np.zeros((*binary.shape, 4), dtype=np.float32)
        overlay[binary] = (red, green, blue, 0.5)
        ax.imshow(overlay)
        drawn += 1

        name = categories[label] if categories is not None else label
        ax.text(box[0].item(), box[1].item(), f'{name} ({score:.2f})',
                bbox=dict(facecolor='yellow', alpha=0.5), fontsize=10, color='black')

    ax.axis('off')
    plt.show()

# Demo: segment a product photo and show every detection scoring at least 0.5.
image_url = (
    'https://www.russellandbromley.co.uk/ccstore/v1/images/'
    '?source=/file/v5011827438628642521/products/242458_xlalt4.jpg'
    '&height=800&width=800&quality=1.0'
)
detections, img = get_segmented_parts(image_url)
visualize_segmented_parts(detections, img, threshold=0.5)
