In [5]:
import os
import sys
import warnings
import torch
import cv2
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from datasets import load_dataset
from pycocotools.coco import COCO
from PIL import Image
from huggingface_hub import hf_hub_download

In [10]:
# --- Hugging Face Hub sources -------------------------------------------------
HF_REPO_NAME = "peaceAsh/fashion_sam_dataset_v2"    # image/mask rows, read with load_dataset
COCO_DATASET = "peaceAsh/fashion_seg_coco_dataset"  # dataset repo holding the COCO annotations
JSON_FILE = "result.json"                           # COCO annotation file inside COCO_DATASET

In [47]:
class FashionSAMDataset(Dataset):
    """Fashion segmentation dataset yielding resized image/mask pairs plus point prompts.

    Row ``idx`` of ``hf_dataset`` is paired with the ``idx``-th COCO image when the
    COCO images are sorted by ``file_name``; 1-3 positive points are sampled from
    each COCO annotation of that image. This positional pairing only holds while
    ``hf_dataset`` keeps that order: shuffling or splitting ``hf_dataset`` itself
    (e.g. ``train_test_split``) renumbers rows and pairs images with another
    image's annotations. Split the wrapping dataset instead (e.g. with
    ``torch.utils.data.random_split``), which keeps the original indices.

    Args:
        hf_dataset: indexable rows with PIL ``'image'`` and ``'mask'`` entries.
        coco_api: ``pycocotools.coco.COCO``-like object (``getImgIds``, ``loadImgs``,
            ``getAnnIds``, ``loadAnns``, ``annToMask``).
        image_size: side length of the square the image and mask are resized to.
        max_points: number of prompt slots per item; extra points are dropped.
        pad_label: label written into unused prompt slots. Defaults to -1, which
            Segment Anything's prompt encoder treats as "not a point"; 0 would
            mean a negative (background) click.

    Each item is a dict with:
        ``image``: uint8 tensor (3, image_size, image_size)
        ``mask``: uint8 tensor (1, image_size, image_size)
        ``points``: float32 tensor (max_points, 2) of (x, y) in resized pixel
            coordinates; unused slots are (-1, -1)
        ``point_labels``: int64 tensor (max_points,); 1 = positive, ``pad_label`` = unused
    """

    def __init__(self, hf_dataset, coco_api, image_size=1024, max_points=32, pad_label=-1):
        self.hf_dataset = hf_dataset
        self.coco_api = coco_api
        self.image_size = image_size
        self.max_points = max_points
        self.pad_label = pad_label

        coco_images = self.coco_api.loadImgs(self.coco_api.getImgIds())
        self.filename_to_id = {img['file_name']: img['id'] for img in coco_images}
        self.filenames = sorted(self.filename_to_id.keys())
        # Positional pairing: row i of hf_dataset <-> i-th COCO file name (sorted).
        self.idx_to_filename = dict(enumerate(self.filenames))

        if len(self.hf_dataset) != len(self.filenames):
            # A size mismatch means the positional pairing above cannot line up
            # (e.g. a split subset of the HF dataset was passed in).
            warnings.warn(
                f"hf_dataset has {len(self.hf_dataset)} rows but COCO has "
                f"{len(self.filenames)} images; rows are paired with COCO images "
                "by position, so point prompts may come from the wrong image."
            )

    def __len__(self):
        """Number of rows in the wrapped HF dataset."""
        return len(self.hf_dataset)

    def _sample_points_from_ann(self, ann, scale_x, scale_y, original_w, original_h):
        """Sample 1-3 random foreground points of one annotation, in resized coordinates.

        Args:
            ann: COCO annotation, decoded with ``coco_api.annToMask``.
            scale_x, scale_y: factors mapping original-image pixels to the resized image.
            original_w, original_h: size of the image the scales were computed from.

        Returns:
            List of ``[x, y]`` int pairs clamped to ``[0, image_size - 1]``; empty
            when the annotation mask has no foreground pixels.
        """
        individual_mask = self.coco_api.annToMask(ann)
        mask_h, mask_w = individual_mask.shape[:2]
        if (mask_h, mask_w) != (original_h, original_w):
            # The annotation is rasterised on the COCO image's grid; when that
            # differs from the HF image size, scale from the annotation grid.
            scale_x = self.image_size / float(mask_w)
            scale_y = self.image_size / float(mask_h)

        # Erode with a 3x3 kernel to keep points away from instance edges...
        kernel = np.ones((3, 3), np.uint8)
        coords = np.argwhere(cv2.erode(individual_mask, kernel, iterations=1) > 0)  # (row, col)
        if len(coords) == 0:
            # ...but thin instances can erode away entirely; fall back to the raw
            # mask instead of silently dropping this instance's prompts.
            coords = np.argwhere(individual_mask > 0)
        if len(coords) == 0:
            return []

        num_to_pick = min(np.random.randint(1, 4), len(coords))  # 1..3 points
        chosen = coords[np.random.choice(len(coords), size=num_to_pick, replace=False)]

        pts = []
        for row, col in chosen:  # row = y, col = x
            x_resized = min(max(int(col * scale_x), 0), self.image_size - 1)
            y_resized = min(max(int(row * scale_y), 0), self.image_size - 1)
            pts.append([x_resized, y_resized])
        return pts

    def _collect_points(self, idx, scale_x, scale_y, original_w, original_h):
        """Sample points for every COCO annotation of the image paired with row ``idx``.

        Returns at most ``max_points`` ``[x, y]`` pairs; empty when no COCO image is
        paired with ``idx``.
        """
        filename = self.idx_to_filename.get(idx)
        img_id = self.filename_to_id.get(filename)
        if img_id is None:
            return []

        points = []
        for ann in self.coco_api.loadAnns(self.coco_api.getAnnIds(imgIds=img_id)):
            points.extend(self._sample_points_from_ann(ann, scale_x, scale_y, original_w, original_h))
        return points[:self.max_points]

    def _pad_points(self, points):
        """Pack up to ``max_points`` (x, y) points into fixed-size arrays.

        Returns:
            ``(coords, labels)``: float32 ``(max_points, 2)`` prefilled with -1 and
            int64 ``(max_points,)`` prefilled with ``pad_label``; the first
            ``len(points)`` slots hold the given points with label 1.
        """
        points = list(points)[:self.max_points]
        coords = np.full((self.max_points, 2), -1, dtype=np.float32)
        labels = np.full((self.max_points,), self.pad_label, dtype=np.int64)
        if points:
            coords[:len(points)] = np.asarray(points, dtype=np.float32)
            labels[:len(points)] = 1
        return coords, labels

    def __getitem__(self, idx):
        entry = self.hf_dataset[idx]
        image_np = np.array(entry['image'].convert("RGB"))  # H, W, 3 (uint8)
        mask_np = np.array(entry['mask'].convert("L"))      # H, W    (uint8)
        original_h, original_w = image_np.shape[:2]

        target_size = (self.image_size, self.image_size)
        image_resized = cv2.resize(image_np, target_size)
        # Nearest-neighbour so resizing never invents new mask values.
        mask_resized = cv2.resize(mask_np, target_size, interpolation=cv2.INTER_NEAREST)

        # Scale factors actually applied by the resize above.
        scale_x = self.image_size / float(original_w)
        scale_y = self.image_size / float(original_h)

        points = self._collect_points(idx, scale_x, scale_y, original_w, original_h)
        pts_arr, point_labels = self._pad_points(points)

        return {
            "image": torch.from_numpy(image_resized).permute(2, 0, 1).to(torch.uint8),  # (3, H, W)
            "mask": torch.from_numpy(mask_resized).unsqueeze(0).to(torch.uint8),       # (1, H, W)
            "points": torch.from_numpy(pts_arr),                                       # (max_points, 2)
            "point_labels": torch.from_numpy(point_labels),                            # (max_points,)
        }


In [8]:
# Fetch the COCO annotation JSON from the dataset repo; the result is a local file path.
downloaded_coco_path = hf_hub_download(repo_id=COCO_DATASET, filename=JSON_FILE, repo_type="dataset")

result.json: 0.00B [00:00, ?B/s]

In [11]:
# Image/mask rows (single 'train' split) and the COCO index used for point prompts.
fashion_ds = load_dataset(path=HF_REPO_NAME, split='train')
coco = COCO(annotation_file=downloaded_coco_path)

loading annotations into memory...
Done (t=0.00s)
creating index...
index created!


In [48]:
# FashionSAMDataset pairs row i of the HF dataset with the i-th COCO image (sorted by
# file name), so it must wrap the *unsplit* dataset. Splitting the HF dataset first
# (`train_test_split`) shuffles and renumbers rows, which pairs each image with
# another image's annotations. `random_split` instead hands the wrapper the original
# row indices, and the seeded generator makes the split reproducible.
SPLIT_SEED = 42
VAL_FRACTION = 0.1

full_dataset = FashionSAMDataset(fashion_ds, coco)
n_val = int(np.ceil(VAL_FRACTION * len(full_dataset)))
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset,
    [len(full_dataset) - n_val, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Raw HF views of each split, for inspection.
train_hf_dataset = fashion_ds.select(train_dataset.indices)
val_hf_dataset = fashion_ds.select(val_dataset.indices)

train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)