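To download only a subset of the COCO dataset, we:

1- Download the COCO annotation files and, using them, only the images that contain the selected categories (not the whole image set).

2- Create a dataset that selects only those categories; the downloaded images are stored in the "data" folder (the dataset class can optionally copy the selected images into another folder).

3- Move the "data" folder to your HPC cluster.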

In [2]:
import os
import json
import requests
from tqdm import tqdm
from pycocotools.coco import COCO
import shutil


"""The approach followed here:
    1. Choose a subset of classes from COCO by specifying their category ids,
       download the annotations, and then download only the images that
       contain the specified classes."""
# Configuration
config = {
    # COCO 2017 category ids: 84 = book, 31 = handbag, 52 = banana.
    # NOTE(review): this list was previously commented as "book, wine glass,
    # banana", but wine glass is id 46 in COCO 2017 -- confirm whether 31
    # (handbag) or 46 (wine glass) was intended before regenerating the data.
    "categories_to_keep": [84, 31, 52],
    # Annotations whose COCO "area" is below this value are ignored when
    # selecting images and when building training targets.
    "min_area_threshold": 655,
    # Root folder for the annotation archives, extracted JSON and images.
    "data_dir": "data"
}

# URLs for COCO annotation files.
# The trainval archive provides both instances_train2017.json and
# instances_val2017.json. The test archive (image_info_*) presumably carries
# image metadata only, with no instance annotations -- verify before using
# it for category filtering.
annotation_urls = {
    "train": "http://images.cocodataset.org/annotations/annotations_trainval2017.zip",
    "test": "http://images.cocodataset.org/annotations/image_info_test2017.zip"
}

# Create data directory
os.makedirs(config["data_dir"], exist_ok=True)

def download_file(url, destination, timeout=60, chunk_size=64 * 1024):
    """Stream ``url`` to ``destination`` with a progress bar.

    Does nothing (apart from printing a note) when ``destination`` already
    exists. The response body is first written to ``destination + ".part"``
    and only moved into place once the whole body has been received, so a
    failed or interrupted download never leaves a truncated file that a later
    run would mistake for a complete one.

    Args:
        url: HTTP(S) URL to fetch.
        destination: Local path of the file to create.
        timeout: Seconds passed to ``requests.get`` as the connect/read timeout.
        chunk_size: Maximum number of bytes read per iteration of the stream.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
        requests.RequestException: on connection errors or timeouts.
        OSError: if the file cannot be written.
    """
    if os.path.exists(destination):
        print(f"File already exists: {destination}")
        return

    tmp_path = f"{destination}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an HTTP error page would be saved as the file.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(tmp_path, 'wb') as f, tqdm(
                desc=os.path.basename(destination),
                total=total_size or None,  # None -> open-ended bar if size unknown
                unit='B',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size):
                    f.write(data)
                    bar.update(len(data))

        # Atomic on the same filesystem: destination is either absent or complete.
        os.replace(tmp_path, destination)
    finally:
        # Only present here if something above failed; remove the partial file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def extract_zip(zip_path, extract_to):
    """Unpack every member of the archive at ``zip_path`` into ``extract_to``.

    Files already present under the same names are overwritten. A
    confirmation line is printed once extraction has finished.
    """
    from zipfile import ZipFile

    archive = ZipFile(zip_path, mode='r')
    with archive:
        archive.extractall(path=extract_to)
    print(f"Extracted {zip_path} to {extract_to}")

def download_and_extract_annotations(split):
    """Download (if needed) and extract the COCO annotation archive for ``split``.

    ``"train"`` and ``"val"`` both come from the trainval archive (the val
    JSON is extracted next to the train JSON), so ``"val"`` reuses the
    ``"train"`` URL and zip instead of needing its own entry in
    ``annotation_urls``. ``"test"`` uses the test image-info archive.

    Args:
        split: One of ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        Path to the extracted annotation JSON for ``split``.

    Raises:
        ValueError: if ``split`` is not one of the supported names.
    """
    json_names = {
        "train": "instances_train2017.json",
        "val": "instances_val2017.json",
        "test": "image_info_test2017.json",
    }
    if split not in json_names:
        raise ValueError(f"Unknown split {split!r}; expected one of {sorted(json_names)}")

    # There is no separate val archive: val annotations ship in the train/val zip.
    archive_split = "train" if split == "val" else split
    zip_path = os.path.join(config["data_dir"], f"{archive_split}_annotations.zip")

    # Download annotation zip (skipped if it is already on disk)
    download_file(annotation_urls[archive_split], zip_path)

    # Extract annotations
    extract_zip(zip_path, config["data_dir"])

    return os.path.join(config["data_dir"], "annotations", json_names[split])

def download_images_for_categories(split, annotation_file, categories_to_keep, min_area_threshold):
    """Download only the images of ``split`` that contain the requested categories.

    An image is selected when it has at least one non-crowd annotation of one
    of ``categories_to_keep`` whose ``area`` is >= ``min_area_threshold``.
    Images are fetched from images.cocodataset.org into
    ``<config["data_dir"]>/<split>``; files that already exist there are not
    downloaded again. Download errors are reported and skipped (best effort),
    and the number of failures is summarised at the end so they are not lost
    among the per-file progress bars.

    Args:
        split: ``"train"`` or ``"val"`` select the matching image folder on the
            server; any other value uses the test2017 folder. Also names the
            output sub-directory.
        annotation_file: Path to a COCO-format annotation JSON.
        categories_to_keep: COCO category ids to select.
        min_area_threshold: Minimum annotation ``area`` for an instance to count.

    Returns:
        The directory the images were written to.
    """
    # Load annotations
    coco = COCO(annotation_file)

    # A set deduplicates images that match several categories.
    selected = set()
    for cat_id in categories_to_keep:
        anns = coco.loadAnns(coco.getAnnIds(catIds=[cat_id], iscrowd=False))
        selected.update(ann['image_id'] for ann in anns if ann['area'] >= min_area_threshold)
    image_ids = sorted(selected)  # deterministic order for image_ids.json
    print(f"Found {len(image_ids)} images with categories {categories_to_keep}")

    # Create output directory
    output_dir = os.path.join(config["data_dir"], split)
    os.makedirs(output_dir, exist_ok=True)

    # COCO image URL format: http://images.cocodataset.org/<folder>/<file_name>
    image_folder = {"train": "train2017", "val": "val2017"}.get(split, "test2017")

    failed = []
    for img_id in tqdm(image_ids, desc=f"Downloading {split} images"):
        file_name = coco.loadImgs(img_id)[0]["file_name"]
        img_url = f"http://images.cocodataset.org/{image_folder}/{file_name}"
        dst_path = os.path.join(output_dir, file_name)

        if os.path.exists(dst_path):
            continue
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        try:
            download_file(img_url, dst_path)
        except Exception as e:  # best effort: keep going, report at the end
            print(f"Error downloading {img_url}: {e}")
            failed.append(file_name)

    # Save list of image IDs for future reference
    with open(os.path.join(output_dir, "image_ids.json"), "w") as f:
        json.dump(image_ids, f)

    if failed:
        print(f"WARNING: {len(failed)} of {len(image_ids)} {split} images failed to download")
    print(f"Downloaded {split} images to {output_dir}")
    return output_dir

# Process train split
print("Processing train split...")
train_annotation_file = download_and_extract_annotations("train")
train_dir = download_images_for_categories("train", train_annotation_file,
                                           config["categories_to_keep"],
                                           config["min_area_threshold"])

# Process validation split.
# The val annotations are extracted from the same trainval archive, next to
# the train JSON. Build that path explicitly: str.replace("train", "val") would
# also rewrite any "train" appearing elsewhere in the path (e.g. in data_dir).
print("Processing validation split...")
val_annotation_file = os.path.join(os.path.dirname(train_annotation_file),
                                   "instances_val2017.json")
val_dir = download_images_for_categories("val", val_annotation_file,
                                         config["categories_to_keep"],
                                         config["min_area_threshold"])

# Record where everything ended up so later cells/sessions can reuse it
config["train_image_dir"] = train_dir
config["val_image_dir"] = val_dir
config["train_annotation_file"] = train_annotation_file
config["val_annotation_file"] = val_annotation_file

# Save updated config
with open("config.json", "w") as f:
    json.dump(config, f, indent=2)

print("Finished downloading dataset. Config file updated.")

Processing train split...


train_annotations.zip: 100%|██████████| 241M/241M [00:07<00:00, 34.4MB/s] 


Extracted data/train_annotations.zip to data
loading annotations into memory...
Done (t=16.99s)
creating index...
index created!
Found 10267 images with categories [84, 31, 52]


000000163840.jpg: 100%|██████████| 179k/179k [00:00<00:00, 86.3MB/s]
000000131074.jpg: 100%|██████████| 168k/168k [00:00<00:00, 12.3MB/s]98it/s]
000000294914.jpg: 100%|██████████| 101k/101k [00:00<00:00, 95.6MB/s]95it/s]
000000458756.jpg: 100%|██████████| 95.4k/95.4k [00:00<00:00, 76.3MB/s]it/s]
000000294918.jpg: 100%|██████████| 122k/122k [00:00<00:00, 80.3MB/s]28it/s]
000000229383.jpg: 100%|██████████| 153k/153k [00:00<00:00, 93.1MB/s]25it/s]
000000294920.jpg: 100%|██████████| 114k/114k [00:00<00:00, 72.3MB/s]36it/s]
000000196619.jpg: 100%|██████████| 52.5k/52.5k [00:00<00:00, 64.4MB/s]it/s]
000000163852.jpg: 100%|██████████| 122k/122k [00:00<00:00, 74.6MB/s]04it/s]
000000262162.jpg: 100%|██████████| 128k/128k [00:00<00:00, 74.7MB/s]76it/s]
000000458772.jpg: 100%|██████████| 79.7k/79.7k [00:00<00:00, 77.9MB/s]5it/s]
000000294933.jpg: 100%|██████████| 145k/145k [00:00<00:00, 41.0MB/s].23it/s]
000000229401.jpg: 100%|██████████| 241k/241k [00:00<00:00, 97.1MB/s].34it/s]
000000360473.jpg

Downloaded train images to data/train
Processing validation split...
loading annotations into memory...
Done (t=0.43s)
creating index...
index created!
Found 443 images with categories [84, 31, 52]


000000577539.jpg: 100%|██████████| 161k/161k [00:00<00:00, 63.5MB/s]
000000184324.jpg: 100%|██████████| 117k/117k [00:00<00:00, 87.2MB/s]/s]
000000575500.jpg: 100%|██████████| 142k/142k [00:00<00:00, 437kB/s]t/s]
000000479248.jpg: 100%|██████████| 194k/194k [00:00<00:00, 1.83MB/s]/s]
000000407574.jpg: 100%|██████████| 234k/234k [00:00<00:00, 93.7MB/s]/s]
000000331799.jpg: 100%|██████████| 278k/278k [00:00<00:00, 98.6MB/s]/s]
000000387098.jpg: 100%|██████████| 134k/134k [00:00<00:00, 83.7MB/s]/s]
000000161820.jpg: 100%|██████████| 191k/191k [00:00<00:00, 86.6MB/s]/s]
000000024610.jpg: 100%|██████████| 96.0k/96.0k [00:00<00:00, 65.0MB/s]]
000000421923.jpg: 100%|██████████| 165k/165k [00:00<00:00, 49.7MB/s]/s]
000000401446.jpg: 100%|██████████| 208k/208k [00:00<00:00, 62.1MB/s]t/s]
000000368684.jpg: 100%|██████████| 101k/101k [00:00<00:00, 47.4MB/s]t/s]
000000542776.jpg: 100%|██████████| 80.3k/80.3k [00:00<00:00, 46.4MB/s]s]
000000147518.jpg: 100%|██████████| 153k/153k [00:00<00:00, 71.5M

Downloaded val images to data/val
Finished downloading dataset. Config file updated.


In [5]:
import json
import os
import numpy as np
import matplotlib.patches as patches
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from pycocotools.coco import COCO
from PIL import Image
import os
import numpy as np
import cv2  # OpenCV needed for polygon drawing
from pycocotools import mask as coco_mask
from matplotlib import pyplot as plt

class CocoSegmentationDatasetMRCNN(Dataset):
    """COCO instance-segmentation dataset that yields Mask R-CNN style targets.

    Only images with at least one non-crowd annotation of one of
    ``categories_to_keep`` whose ``area`` is >= ``min_area_threshold`` are
    kept, and only those annotations are returned. Every kept instance gets
    label ``1`` (a single foreground class); the original COCO category id is
    returned in ``target["category_ids"]`` for reference.

    Args:
        image_dir: Directory holding the images named by each image's
            ``file_name`` in the annotation file.
        seg_annotation_file: Path to a COCO instances JSON.
        categories_to_keep: COCO category ids to keep. ``None`` means ``[1]``.
        min_area_threshold: Minimum annotation ``area`` for an instance to count.
        output_dir: If given, the selected images are copied there on init.
    """

    def __init__(self, image_dir, seg_annotation_file, categories_to_keep=None, min_area_threshold=655, output_dir=None):
        self.image_dir = image_dir
        self.coco_seg = COCO(seg_annotation_file)
        self.min_area_threshold = min_area_threshold
        # Copy into a fresh list so instances never share a mutable default
        # or alias the caller's list.
        self.categories_to_keep = [1] if categories_to_keep is None else list(categories_to_keep)
        self.output_dir = output_dir

        self.image_ids = self._select_image_ids()

        # Some selected images may be absent from image_dir (e.g. a failed
        # download); drop them now instead of raising inside __getitem__.
        missing = {img_id for img_id in self.image_ids
                   if not os.path.exists(self._image_path(img_id))}
        if missing:
            print(f"Warning: {len(missing)} selected images not found in {self.image_dir}; skipping them")
            self.image_ids = [img_id for img_id in self.image_ids if img_id not in missing]

        print(f"Dataset contains {len(self.image_ids)} images with categories {self.categories_to_keep}")

        # category id -> human-readable name, for visualization
        self.category_map = {
            cat_id: self.coco_seg.loadCats(cat_id)[0]['name']
            for cat_id in self.categories_to_keep
        }

        # Copy images to output directory if specified
        if output_dir:
            self.copy_images_to_output_dir()

    def _select_image_ids(self):
        """Return sorted ids of images with >= 1 qualifying annotation."""
        selected = set()
        for cat_id in self.categories_to_keep:
            ann_ids = self.coco_seg.getAnnIds(catIds=[cat_id], iscrowd=False)
            selected.update(
                ann['image_id']
                for ann in self.coco_seg.loadAnns(ann_ids)
                if ann['area'] >= self.min_area_threshold
            )
        # Sorted so that index -> image mapping is reproducible across runs.
        return sorted(selected)

    def _image_path(self, image_id):
        """Return the path of ``image_id``'s file inside ``image_dir``."""
        return os.path.join(self.image_dir, self.coco_seg.loadImgs(image_id)[0]["file_name"])

    def copy_images_to_output_dir(self):
        """Copy the selected images from ``image_dir`` into ``output_dir``.

        Does nothing when ``output_dir`` is falsy. Sub-directories contained in
        ``file_name`` are recreated under ``output_dir``; source files that do
        not exist are reported and skipped.
        """
        if not self.output_dir:
            return

        os.makedirs(self.output_dir, exist_ok=True)
        print(f"Copying {len(self.image_ids)} images to {self.output_dir}...")
        copied_count = 0

        for img_id in self.image_ids:
            file_name = self.coco_seg.loadImgs(img_id)[0]["file_name"]
            src_path = os.path.join(self.image_dir, file_name)
            dst_path = os.path.join(self.output_dir, file_name)

            if not os.path.exists(src_path):
                print(f"Warning: Could not find {src_path}")
                continue
            os.makedirs(os.path.dirname(dst_path), exist_ok=True)
            shutil.copy(src_path, dst_path)
            copied_count += 1

        print(f"Copied {copied_count} images to {self.output_dir}")

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the ``idx``-th selected image.

        ``image`` is the RGB image converted with ``transforms.ToTensor()``.
        ``target`` contains, for the N kept annotations:
          - ``boxes``: float32 (N, 4) in ``[x1, y1, x2, y2]`` format
          - ``labels``: int64 (N,), all 1 (single foreground class)
          - ``masks``: uint8 (N, H, W) binary instance masks
          - ``category_ids``: int64 (N,), the original COCO category ids
          - ``image_id``: 1-element tensor with the COCO image id
        """
        image_id = self.image_ids[idx]

        # Load image; the context manager closes the underlying file handle.
        with Image.open(self._image_path(image_id)) as img:
            image = transforms.ToTensor()(img.convert("RGB"))

        ann_ids = self.coco_seg.getAnnIds(imgIds=image_id, catIds=self.categories_to_keep, iscrowd=False)
        anns = self.coco_seg.loadAnns(ann_ids)

        target = {}
        boxes = []
        masks = []
        labels = []
        category_ids = []  # original category IDs, for reference

        for ann in anns:
            # Same area filter as used when selecting images
            if ann['area'] < self.min_area_threshold:
                continue

            # COCO bbox is [x, y, width, height]; convert to [x1, y1, x2, y2]
            x, y, w, h = ann['bbox']
            boxes.append([x, y, x + w, y + h])

            masks.append(torch.as_tensor(self.coco_seg.annToMask(ann), dtype=torch.uint8))
            category_ids.append(ann['category_id'])

            # Segmentation only: every foreground object is class 1 (0 = background)
            labels.append(1)

        if boxes:
            target["boxes"] = torch.as_tensor(boxes, dtype=torch.float32)
            target["labels"] = torch.as_tensor(labels, dtype=torch.int64)
            target["masks"] = torch.stack(masks)
            target["category_ids"] = torch.as_tensor(category_ids, dtype=torch.int64)
        else:
            # No qualifying annotations: return correctly-shaped empty tensors
            target["boxes"] = torch.zeros((0, 4), dtype=torch.float32)
            target["labels"] = torch.zeros((0), dtype=torch.int64)
            target["masks"] = torch.zeros((0, image.shape[1], image.shape[2]), dtype=torch.uint8)
            target["category_ids"] = torch.zeros((0), dtype=torch.int64)

        target["image_id"] = torch.tensor([image_id])

        return image, target

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
# Wrap the downloaded subset in datasets. Train and val share the same
# category filter and minimum-area threshold recorded in ``config``.
dataset_train = CocoSegmentationDatasetMRCNN(
    image_dir=config["train_image_dir"],
    seg_annotation_file=config["train_annotation_file"],
    categories_to_keep=config["categories_to_keep"],
    min_area_threshold=config["min_area_threshold"],
)

dataset_val = CocoSegmentationDatasetMRCNN(
    image_dir=config["val_image_dir"],
    seg_annotation_file=config["val_annotation_file"],
    categories_to_keep=config["categories_to_keep"],
    min_area_threshold=config["min_area_threshold"],
)

loading annotations into memory...
Done (t=16.93s)
creating index...
index created!
Dataset contains 10267 images with categories [84, 31, 52]
loading annotations into memory...
Done (t=0.53s)
creating index...
index created!
Dataset contains 443 images with categories [84, 31, 52]
