In [None]:
from pycocotools.coco import COCO
import os
import torch
from torchvision import models,transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# Dataset

In [None]:
class COCODataset(Dataset):
    """Map-style dataset that pairs each COCO image with its annotation count.

    Every item is a tuple ``(image, num_objects)`` where ``num_objects`` is the
    number of annotation records the COCO index returns for that image.
    """

    def __init__(self, root_dir, annotation_file, transform=None):
        """Index the annotation file and remember where the images live.

        Args:
            root_dir: Directory that the images' ``file_name`` entries are
                joined onto.
            annotation_file: Path handed to ``pycocotools.coco.COCO``.
            transform: Optional callable applied to the loaded RGB image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.coco = COCO(annotation_file)
        # Iterating the image index yields its keys, i.e. the image IDs.
        self.image_ids = list(self.coco.imgs)

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and count its annotations.

        Returns:
            ``(image, num_objects)``: the RGB image (passed through
            ``transform`` when one is set) and the number of annotations
            loaded for that image ID.
        """
        image_id = self.image_ids[idx]
        record = self.coco.loadImgs(image_id)[0]
        image = Image.open(
            os.path.join(self.root_dir, record['file_name'])
        ).convert('RGB')

        # Look up the annotation IDs, load the records and count them.
        ann_ids = self.coco.getAnnIds(imgIds=image_id)
        num_objects = len(self.coco.loadAnns(ann_ids))

        if not self.transform:
            return image, num_objects
        return self.transform(image), num_objects

In [None]:
# Data transformation
# Resize to the 224x224 resolution, convert the PIL image to a float tensor,
# then normalize each channel with the ImageNet mean/std. The ImageNet-pretrained
# ResNet-50 used for feature extraction in this notebook was trained on inputs
# normalized this way; feeding it un-normalized tensors degrades the features.
# NOTE: undo the normalization before displaying a transformed image.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
# Locations of the COCO 2017 training images and their instance annotations.
# Adjust these to match your local layout.
train_images_dir = "data/COCO/train2017"
# Despite the "_dir" suffix, this points at a single JSON annotation file.
train_annotations_dir = "data/COCO/annotations/instances_train2017.json"

# Build the training dataset, then wrap it in a shuffling loader that
# yields batches of 32 samples.
train_dataset = COCODataset(
    root_dir=train_images_dir,
    annotation_file=train_annotations_dir,
    transform=transform,
)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)

In [None]:
# Sanity check: report the number of samples in the dataset and the number
# of batches the loader will produce.

for _label, _sized in (
    ("train_dataset size:", train_dataset),
    ("train_loader size:", train_dataloader),
):
    print(_label, len(_sized))

In [None]:
# Feature extraction

# Load an ImageNet-pretrained ResNet-50.
# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights=`
# argument; IMAGENET1K_V1 is the checkpoint the legacy flag used to load.
# Older torchvision releases have no `ResNet50_Weights`, so fall back to the
# legacy flag there.
_resnet50_weights = getattr(models, "ResNet50_Weights", None)
if _resnet50_weights is not None:
    resnet = models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    resnet = models.resnet50(pretrained=True)

# Drop the last child module (the classification layer) so the network
# outputs pooled features instead of class logits.
resnet = torch.nn.Sequential(*list(resnet.children())[:-1])  # Remove final layer
resnet.eval()  # Evaluation mode: fixes batch-norm statistics during extraction

# Example of feature extraction from an image batch
def extract_features(images, model=None):
    """Run a batch through a feature extractor and flatten the result per sample.

    Args:
        images: Batched input tensor accepted by ``model``. For the truncated
            ResNet-50 defined in this notebook this is presumably
            ``[batch_size, 3, H, W]``; the shape is not checked here.
        model: Optional callable used as the feature extractor. When omitted,
            the notebook-level ``resnet`` is used, which keeps existing
            ``extract_features(images)`` calls working unchanged.

    Returns:
        Tensor of shape ``[batch_size, num_features]`` computed with gradient
        tracking disabled (``[batch_size, 2048]`` for the truncated ResNet-50).
    """
    if model is None:
        model = resnet  # fall back to the notebook-level extractor
    with torch.no_grad():
        features = model(images)  # e.g. [batch_size, 2048, 1, 1] for ResNet-50
        # flatten(start_dim=1) keeps the batch axis and also works for an
        # empty batch, where view(batch, -1) fails because -1 is ambiguous.
        features = torch.flatten(features, start_dim=1)
    return features

