In order to download only a subset of the COCO dataset we need to:

1- Download the whole dataset

2- Create a dataset that selects only some categories and copies the matching images into a "data" folder

3- Move the "data" folder to your HPC cluster

In [None]:
import fiftyone as fo
import fiftyone.zoo as foz

# Download the COCO-2017 validation split (if needed) and load it into FiftyOne.
# The returned dataset object is kept in the notebook-level name `dataset`.
dataset = foz.load_zoo_dataset("coco-2017", split="validation")

# Give the dataset an explicit name and flag it as persistent
# (presumably so it is kept in the FiftyOne database between sessions -- confirm).
dataset.name = "coco-2017-validation"
dataset.persistent = True


In [None]:
# Same as the validation cell, for the "train" split of COCO-2017.
# NOTE: the name `dataset_train` is rebound later in this notebook to a
# CocoSegmentationDatasetMRCNN instance, so this FiftyOne handle is only
# reachable through this variable until that cell runs.
dataset_train = foz.load_zoo_dataset("coco-2017", split="train")
dataset_train.name = "coco-2017-train"
dataset_train.persistent = True

In [None]:
# Same as the validation cell, for the "test" split of COCO-2017.
# The dataset is named explicitly and flagged as persistent.
dataset_test = foz.load_zoo_dataset("coco-2017", split="test")
dataset_test.name = "coco-2017-test"
dataset_test.persistent = True

In [None]:
import json
import os
import numpy as np
import matplotlib.patches as patches
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from pycocotools.coco import COCO
from PIL import Image
import os
import numpy as np
import cv2  # OpenCV needed for polygon drawing
from pycocotools import mask as coco_mask
from matplotlib import pyplot as plt
import shutil

# Load the notebook configuration. The cells below read the keys
# "train_image_dir", "train_annotation_file", "val_image_dir" and
# "val_annotation_file" from it.
# A context manager is used so the file handle is closed deterministically
# instead of being left to the garbage collector.
with open("config.json", encoding="utf-8") as config_file:
    config = json.load(config_file)

In [None]:
class CocoSegmentationDatasetMRCNN(Dataset):
    """COCO segmentation dataset restricted to a chosen set of categories.

    An image is kept when it has at least one non-crowd annotation that
    belongs to one of ``categories_to_keep`` and whose ``area`` is at least
    ``min_area_threshold``. Optionally, the kept images are copied into
    ``output_dir`` at construction time.

    Args:
        image_dir: Directory containing the image files referenced by the
            annotation file.
        seg_annotation_file: Path to the COCO annotation JSON file.
        categories_to_keep: Iterable of COCO category ids to keep.
            ``None`` (the default) means ``[1]``.
        min_area_threshold: Annotations with a smaller ``area`` are ignored.
        output_dir: If truthy, the kept images are copied into this directory.

    Attributes set by the constructor: ``image_dir``, ``coco_seg``,
    ``min_area_threshold``, ``categories_to_keep`` (a list), ``output_dir``,
    ``image_ids`` (sorted list of kept image ids) and ``category_map``
    (category id -> category name).
    """

    def __init__(self, image_dir, seg_annotation_file, categories_to_keep=None, min_area_threshold=655, output_dir=None):
        self.image_dir = image_dir
        self.coco_seg = COCO(seg_annotation_file)
        self.min_area_threshold = min_area_threshold
        # A ``None`` default avoids sharing one mutable list between all
        # instances; copying protects against later mutation by the caller.
        self.categories_to_keep = [1] if categories_to_keep is None else list(categories_to_keep)
        self.output_dir = output_dir

        # Collect the ids of images that contain at least one large-enough,
        # non-crowd annotation of a kept category. A set removes duplicates
        # for images that match several categories or annotations.
        kept_image_ids = set()
        for cat_id in self.categories_to_keep:
            ann_ids = self.coco_seg.getAnnIds(catIds=[cat_id], iscrowd=False)
            for ann in self.coco_seg.loadAnns(ann_ids):
                if ann['area'] >= self.min_area_threshold:
                    kept_image_ids.add(ann['image_id'])

        # Sorted so that dataset indices have a well-defined order instead of
        # depending on set iteration order.
        self.image_ids = sorted(kept_image_ids)
        print(f"Dataset contains {len(self.image_ids)} images with categories {self.categories_to_keep}")

        # For visualization: map each kept category id to its name
        self.category_map = {
            cat_id: self.coco_seg.loadCats(cat_id)[0]['name']
            for cat_id in self.categories_to_keep
        }

        # Copy images to output directory if specified
        if output_dir:
            self.copy_images_to_output_dir()

    def copy_images_to_output_dir(self):
        """Copy every kept image from ``image_dir`` into ``output_dir``.

        Does nothing when ``output_dir`` is falsy. Source files that do not
        exist are reported with a warning and skipped. Progress and the
        final count are printed.
        """
        if not self.output_dir:
            return

        # Create output directory
        os.makedirs(self.output_dir, exist_ok=True)

        print(f"Copying {len(self.image_ids)} images to {self.output_dir}...")
        copied_count = 0

        for img_id in self.image_ids:
            img_info = self.coco_seg.loadImgs(img_id)[0]

            src_path = os.path.join(self.image_dir, img_info["file_name"])
            dst_path = os.path.join(self.output_dir, img_info["file_name"])

            if not os.path.exists(src_path):
                print(f"Warning: Could not find {src_path}")
                continue

            # file_name may contain sub-directories; make sure they exist
            os.makedirs(os.path.dirname(dst_path), exist_ok=True)
            shutil.copy(src_path, dst_path)
            copied_count += 1

        print(f"Copied {copied_count} images to {self.output_dir}")

    def __len__(self):
        """Return the number of kept images."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the image at position ``idx``.

        ``image`` is the RGB image converted with ``transforms.ToTensor()``.
        ``target`` is a dict with:

        * ``boxes``: float32 tensor of ``[x1, y1, x2, y2]`` boxes, one per
          kept annotation (shape ``(0, 4)`` when there is none),
        * ``labels``: int64 tensor, ``1`` for every kept annotation,
        * ``masks``: uint8 tensor with one mask per kept annotation,
        * ``category_ids``: int64 tensor of the original COCO category ids,
        * ``image_id``: tensor holding the COCO image id.

        Only non-crowd annotations of the kept categories with
        ``area >= min_area_threshold`` are included.
        """
        image_id = self.image_ids[idx]

        # Load image
        image_info = self.coco_seg.loadImgs(image_id)[0]
        image_path = os.path.join(self.image_dir, image_info["file_name"])
        image = Image.open(image_path).convert("RGB")

        # Convert to tensor
        image = transforms.ToTensor()(image)

        # Load annotations of the kept categories for this image
        ann_ids = self.coco_seg.getAnnIds(imgIds=image_id, catIds=self.categories_to_keep, iscrowd=False)
        anns = self.coco_seg.loadAnns(ann_ids)

        target = {}
        boxes = []
        masks = []
        labels = []
        category_ids = []  # Keep original category IDs for reference

        for ann in anns:
            if ann['area'] < self.min_area_threshold:
                continue

            # COCO boxes are [x, y, width, height]; convert to [x1, y1, x2, y2]
            x, y, width, height = ann['bbox']
            boxes.append([x, y, x + width, y + height])

            # Binary mask for this annotation
            mask = self.coco_seg.annToMask(ann)
            masks.append(torch.as_tensor(mask, dtype=torch.uint8))

            category_ids.append(ann['category_id'])

            # For segmentation only, use class 1 for all foreground objects
            labels.append(1)  # 1 for foreground, 0 for background

        if boxes:
            target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
            target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
            target["masks"] = torch.stack(masks)
            target["category_ids"] = torch.as_tensor(category_ids, dtype=torch.int64)
        else:
            # No annotation survived the filters: return empty tensors with
            # the expected ranks so downstream code can still index them.
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros((0,), dtype=torch.int64)
            target["masks"] = torch.zeros((0, image.shape[1], image.shape[2]), dtype=torch.uint8)
            target["category_ids"] = torch.zeros((0,), dtype=torch.int64)

        target["image_id"] = torch.tensor([image_id])

        return image, target

In [34]:
# Build the filtered training subset and copy its images into data/train.
# Every argument is passed by keyword so the call reads without having to
# look up the constructor signature.
dataset_train = CocoSegmentationDatasetMRCNN(
    image_dir=config["train_image_dir"],
    seg_annotation_file=config["train_annotation_file"],
    categories_to_keep=[5, 23, 53, 33, 13],
    min_area_threshold=655,
    output_dir="data/train",
)

# Same filter for the validation subset, copied into data/validation.
dataset_val = CocoSegmentationDatasetMRCNN(
    output_dir="data/validation",
    min_area_threshold=655,
    categories_to_keep=[5, 23, 53, 33, 13],
    seg_annotation_file=config["val_annotation_file"],
    image_dir=config["val_image_dir"],
)

loading annotations into memory...
Done (t=5.75s)
creating index...
index created!
Dataset contains 8514 images with categories [5, 23, 53, 33, 13]
Copying 8514 images to data/train...
Copied 8514 images to data/train
loading annotations into memory...
Done (t=0.45s)
creating index...
index created!
Dataset contains 330 images with categories [5, 23, 53, 33, 13]
Copying 330 images to data/validation...
Copied 330 images to data/validation


In [None]:
def get_prediction(model, image_tensor, device, threshold=0.5):
    """Run ``model`` on one image and keep only confident detections.

    The model is switched to eval mode, called on a one-element batch holding
    ``image_tensor`` moved to ``device``, and the first result is filtered so
    that only entries whose score is strictly greater than ``threshold``
    remain.

    Returns:
        dict with the keys ``'boxes'``, ``'labels'``, ``'scores'`` and
        ``'masks'``, each indexed by the same keep-mask.
    """
    model.eval()
    with torch.no_grad():
        batch = [image_tensor.to(device)]
        outputs = model(batch)
    raw = outputs[0]

    # Boolean mask over detections that pass the score threshold
    confident = raw['scores'] > threshold

    return {
        field: raw[field][confident]
        for field in ('boxes', 'labels', 'scores', 'masks')
    }

def visualize_prediction(image, prediction, dataset):
    """Show the original image next to the same image with predicted masks.

    Args:
        image: Image tensor laid out as (channels, height, width); it is moved
            to the CPU before conversion, so a GPU tensor is accepted.
            Assumes float values in [0, 1] so the blended overlay stays in
            matplotlib's displayable range -- TODO confirm against callers.
        prediction: dict with ``'masks'``, ``'scores'`` and ``'labels'``
            tensors. Masks are indexed as ``masks[i, 0]``, i.e. one
            single-channel soft mask per instance.
        dataset: Object that may expose a ``category_map`` dict used to turn
            label ids into names; otherwise ``"Class <id>"`` is shown.

    The left panel shows the untouched image. The right panel shows a dimmed
    copy on which every mask (thresholded at 0.5) is tinted with a random
    color and annotated with its label and score at the mask's centroid.
    """
    # Convert image to an (H, W, C) numpy array
    image_np = image.detach().cpu().permute(1, 2, 0).numpy()

    # Get components
    masks = prediction['masks'].cpu().numpy()
    scores = prediction['scores'].cpu().numpy()
    labels = prediction['labels'].cpu().numpy()

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))

    # Original image
    ax[0].imshow(image_np)
    ax[0].set_title('Original Image')
    ax[0].axis('off')

    ax[1].set_title('Predicted Masks')
    ax[1].axis('off')

    # One random color per instance
    n_instances = len(masks)
    colors = np.random.rand(n_instances, 3)

    # Binarize each soft mask (first channel, threshold at 0.5)
    binary_masks = [masks[i, 0] > 0.5 for i in range(n_instances)]

    # Blend every mask into a single overlay and draw it once. The overlay
    # starts as the dimmed image; inside each mask the pixels become a mix of
    # the original image and the instance color. Accumulating on one array
    # keeps earlier masks visible and avoids reading pixels back from the
    # Axes, which has no get_array() method.
    overlay = image_np * 0.5
    for mask, color in zip(binary_masks, colors):
        overlay[mask] = image_np[mask] * 0.7 + color * 0.3
    ax[1].imshow(overlay)

    # NOTE(review): predicted label ids are looked up in dataset.category_map;
    # confirm that the model's label space matches the keys of that map.
    category_map = getattr(dataset, 'category_map', None)

    for i, (mask, color) in enumerate(zip(binary_masks, colors)):
        label_id = int(labels[i])
        if category_map is not None and label_id in category_map:
            label_name = category_map[label_id]
        else:
            label_name = f"Class {label_id}"

        # Place the text at the mask's center of mass; skip empty masks
        y_indices, x_indices = np.where(mask)
        if len(y_indices) == 0:
            continue
        x_center = int(np.mean(x_indices))
        y_center = int(np.mean(y_indices))

        ax[1].text(x_center, y_center,
                   f"{label_name}: {scores[i]:.2f}",
                   color='white', fontsize=8,
                   ha='center', va='center',
                   bbox=dict(facecolor=color, alpha=0.7, pad=1))

    plt.tight_layout()
    plt.show()

# Test on a validation image
# Smoke test: run the model on the first validation image and display it.
# NOTE: `model` is not defined in the cells shown here; it must already exist
# in the notebook session when this cell runs.
with torch.no_grad():
    # Prefer the GPU when one is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # First sample of the validation set; its target is not needed here
    image, _ = dataset_val[0]

    # Keep only detections scoring above 0.7, then plot them
    prediction = get_prediction(model, image, device, threshold=0.7)
    visualize_prediction(image, prediction, dataset_val)