In [1]:
import os
import sys
import cv2
import json
import math 
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from datasets import load_dataset 
from pycocotools.coco import COCO
from PIL import Image
from huggingface_hub import hf_hub_download

# Make the project root (parent of the notebook's working directory) importable
# so `src.*` modules resolve; the membership check keeps re-runs of this cell
# from piling up duplicate sys.path entries.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
from src.preprocess_data import load_and_split_dataset

In [3]:
# Hugging Face Hub dataset repo holding the images and masks
# (features: image, mask, filename).
HF_REPO_NAME = "peaceAsh/fashion_sam_dataset_v2"
# Hugging Face Hub dataset repo that hosts the COCO-format annotation file.
COCO_DATASET = "peaceAsh/fashion_seg_coco_dataset"
# Annotation file name inside COCO_DATASET.
JSON_FILE = "result.json"

# Prompt-point sampling uses numpy's global RNG and DataLoader shuffling uses
# torch's default generator; seed both so Restart & Run All is reproducible.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
# Load the Hub dataset through the project helper; it yields a DatasetDict with
# 'train', 'validation' and 'test' splits, each exposing image/mask/filename.
fashion_ds = load_and_split_dataset(HF_REPO_NAME)

In [5]:
# Show each split's features and row count.
fashion_ds

DatasetDict({
    train: Dataset({
        features: ['image', 'mask', 'filename'],
        num_rows: 13
    })
    validation: Dataset({
        features: ['image', 'mask', 'filename'],
        num_rows: 1
    })
    test: Dataset({
        features: ['image', 'mask', 'filename'],
        num_rows: 2
    })
})

In [6]:
# Fetch the COCO-format annotation JSON from its Hub dataset repo, then index it
# with pycocotools so annotations can be queried per image id.
coco_path = hf_hub_download(repo_id=COCO_DATASET, filename=JSON_FILE, repo_type="dataset")
coco = COCO(coco_path)

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [7]:
# Index every COCO image record by its file_name so dataset rows can be matched
# to their annotation image ids.
coco_imgs = coco.loadImgs(coco.getImgIds())
filename_to_ids = {}
for img_record in coco_imgs:
    filename_to_ids[img_record['file_name']] = img_record['id']

In [8]:
def create_instance_list(dataset, coco, filename_to_ids):
    """Flatten a dataset split into one record per COCO annotation.

    Each row's ``'filename'`` is reduced to its last ``/``-separated component
    and looked up in ``filename_to_ids``; every annotation of the matching COCO
    image becomes one ``{"dataset_idx": i, "annotation": ann}`` record.

    Rows whose filename has no COCO image are skipped and reported together in
    a single ``UserWarning`` rather than aborting the pass with a KeyError.

    Args:
        dataset: Indexable, sized split whose rows expose a ``'filename'`` key.
        coco: Object providing ``getAnnIds(imgIds=...)`` and ``loadAnns(ids)``
            (e.g. ``pycocotools.coco.COCO``).
        filename_to_ids: Mapping from COCO ``file_name`` to image id.

    Returns:
        list of dict: Instance records, in dataset order.
    """
    import warnings

    instance_list = []
    missing = []
    for i in tqdm(range(len(dataset))):
        coco_filename = dataset[i]['filename'].split('/')[-1]
        # .get() so an unannotated image is skipped instead of raising KeyError.
        img_id = filename_to_ids.get(coco_filename)
        if img_id is None:
            missing.append(coco_filename)
            continue
        anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        instance_list.extend({"dataset_idx": i, "annotation": ann} for ann in anns)

    if missing:
        warnings.warn(
            f"{len(missing)} image(s) have no matching COCO image and were skipped: {missing}"
        )
    return instance_list



train_instances = create_instance_list(fashion_ds['train'], coco, filename_to_ids)
val_instances = create_instance_list(fashion_ds['validation'], coco, filename_to_ids)

# Label every count with its split so the image counts are distinguishable.
for split_label, split_key, instances in (
    ("training", "train", train_instances),
    ("validation", "validation", val_instances),
):
    print(f"Number of {split_label} images:{len(fashion_ds[split_key])}")
    print(f"Total {split_label} instances:{len(instances)}")

  0%|          | 0/13 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Number of images:13
Total training instances:37
Number of images:1
Total validation instances:4


In [19]:
class FashionSAMDataset(Dataset):
    """Torch dataset yielding one SAM training sample per annotated instance.

    Each item pairs a letterboxed RGB image with the mask of a single COCO
    annotation and up to ``num_pts_per_instance`` positive point prompts
    sampled inside that mask.

    Args:
        dataset: Indexable split whose rows expose an ``'image'`` entry that
            supports ``.convert("RGB")`` (e.g. a PIL image).
        instance_list: List of ``{"dataset_idx": int, "annotation": dict}``
            records, one per instance to serve.
        coco_api: Object providing ``annToMask(annotation)`` (e.g. a
            ``pycocotools.coco.COCO`` instance).
        image_size: Side length of the square output canvas.
        num_pts_per_instance: Maximum number of prompt points per instance.
    """

    def __init__(self, dataset, instance_list, coco_api, image_size=1024, num_pts_per_instance=3):
        self.dataset = dataset
        self.instance_list = instance_list
        self.coco_api = coco_api
        self.image_size = image_size
        self.num_pts_per_instance = num_pts_per_instance

    def __len__(self):
        return len(self.instance_list)

    def _resized_shape(self, h, w):
        """Return ``(new_h, new_w, scale)`` fitting an ``h x w`` input in the canvas.

        The longest side maps to exactly ``self.image_size`` (up- or
        down-scaling) with aspect ratio preserved. Image and mask both use this
        helper, so they always get identical sizes and padding offsets.
        """
        target_size = self.image_size
        scale = target_size / max(h, w)
        # round() rather than int(): float error in h * scale can fall just
        # below an integer and floor one pixel short.
        new_h = min(target_size, max(1, round(h * scale)))
        new_w = min(target_size, max(1, round(w * scale)))
        return new_h, new_w, scale

    def _resize_and_pad(self, image, is_mask=False):
        """Scale so the longest side equals ``image_size``, then center-pad to a square.

        Args:
            image: A 2-D numpy mask when ``is_mask`` is True, otherwise a PIL image.
            is_mask: Selects nearest-neighbour resizing into a float32 canvas
                (mask) versus Lanczos resizing into a black RGB canvas (image).

        Returns:
            ``(padded, (scale, left, top))``: ``padded`` is an ``(S, S)`` float32
            array for masks or an ``(S, S, 3)`` uint8 array for images;
            ``scale`` maps original pixel coordinates to resized ones and
            ``left``/``top`` are the padding offsets.
        """
        target_size = self.image_size
        if is_mask:
            h, w = image.shape
        else:
            w, h = image.size
        new_h, new_w, scale = self._resized_shape(h, w)
        top = (target_size - new_h) // 2
        left = (target_size - new_w) // 2

        if is_mask:
            resized = cv2.resize(image.astype(np.uint8), (new_w, new_h), interpolation=cv2.INTER_NEAREST)
            padded = np.zeros((target_size, target_size), dtype=np.float32)
            padded[top:top + new_h, left:left + new_w] = resized
        else:
            # Image.resize (unlike thumbnail) also enlarges small inputs and
            # leaves the caller's image untouched, so `scale` is always the true
            # original-to-resized factor used for the prompt points.
            resized = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
            canvas = Image.new("RGB", (target_size, target_size), (0, 0, 0))
            canvas.paste(resized, (left, top))
            padded = np.array(canvas, dtype=np.uint8)

        return padded, (scale, left, top)

    def _sample_points(self, mask):
        """Randomly pick up to ``num_pts_per_instance`` foreground ``(x, y)`` points.

        Points are drawn without replacement from the mask eroded by a 3x3
        kernel, keeping them off the object boundary; if erosion removes every
        pixel, the un-eroded mask is used instead. Returns a float32 array of
        shape ``(k, 2)``, with ``k == 0`` for an empty mask.
        """
        if np.max(mask) == 0:
            return np.empty((0, 2), dtype=np.float32)

        kernel = np.ones((3, 3), np.uint8)
        eroded_mask = cv2.erode(mask, kernel, iterations=1)

        foreground_coords = np.argwhere(eroded_mask > 0)
        if len(foreground_coords) == 0:
            # Thin structures can vanish under erosion; fall back to the raw mask.
            foreground_coords = np.argwhere(mask > 0)

        if len(foreground_coords) == 0:
            return np.empty((0, 2), dtype=np.float32)

        num_available = len(foreground_coords)
        points_to_sample = min(self.num_pts_per_instance, num_available)

        sampled_indices = np.random.choice(num_available, points_to_sample, replace=False)
        # argwhere yields (row, col); reverse to (x, y) pixel order.
        coords = foreground_coords[sampled_indices][:, ::-1].astype(np.float32)
        return coords

    def __getitem__(self, idx):
        """Build the training sample for instance ``idx``.

        Returns a dict with:
            ``image``: float tensor ``(3, S, S)`` of unnormalized 0-255 values.
            ``mask``: float tensor ``(1, S, S)`` holding the instance mask.
            ``points``: float tensor ``(N, 1, 2)`` of ``(x, y)`` prompts in
                padded-canvas pixel coordinates, ``N <= num_pts_per_instance``.
            ``point_labels``: int64 tensor ``(N,)`` of ones (all positive).
        """
        instance_info = self.instance_list[idx]
        dataset_idx = instance_info['dataset_idx']
        annotation = instance_info['annotation']
        image_pil = self.dataset[dataset_idx]['image'].convert("RGB")
        # NOTE(review): annToMask sizes the mask from the COCO image record; this
        # assumes that matches the HF image's resolution — confirm for this data.
        instance_mask = self.coco_api.annToMask(annotation)

        image_padded, (scale, pad_left, pad_top) = self._resize_and_pad(image_pil, is_mask=False)
        mask_padded, _ = self._resize_and_pad(instance_mask, is_mask=True)

        # Points are sampled in original-resolution coordinates, then mapped into
        # the padded canvas with the same scale/offsets applied to the image.
        points = self._sample_points(instance_mask)
        if points.shape[0] > 0:
            points = points * scale
            points[:, 0] += pad_left
            points[:, 1] += pad_top

        image_tensor = torch.tensor(image_padded).permute(2, 0, 1).float()
        mask_tensor = torch.tensor(mask_padded).unsqueeze(0).float()
        # NOTE(review): points are (N, 1, 2) while labels are (N,); confirm this
        # pairing matches what the downstream SAM prompt encoder expects.
        points_tensor = torch.tensor(points).unsqueeze(1)
        labels_tensor = torch.ones(points_tensor.shape[0], dtype=torch.int64)

        return {
            "image": image_tensor,
            "mask": mask_tensor,
            "points": points_tensor,
            "point_labels": labels_tensor
        }

In [20]:
def sam_collate_fn(batch):
    """Collate per-instance sample dicts into a single batch dict.

    ``image`` and ``mask`` are stacked along a new leading batch axis
    (torch.stack requires every sample to share the same shape). ``points``
    and ``point_labels`` stay as plain per-sample lists, since the number of
    sampled points can differ from one instance to the next.
    """
    collated = {}
    for stacked_key in ('image', 'mask'):
        collated[stacked_key] = torch.stack([sample[stacked_key] for sample in batch])
    for listed_key in ('points', 'point_labels'):
        collated[listed_key] = [sample[listed_key] for sample in batch]
    return collated

In [21]:
# One FashionSAMDataset per split; each item is a single annotated instance
# (class defaults: 1024 px canvas, up to 3 prompt points per instance).
train_dataset = FashionSAMDataset(dataset=fashion_ds['train'], instance_list=train_instances, coco_api=coco)
val_dataset = FashionSAMDataset(dataset=fashion_ds['validation'], instance_list=val_instances, coco_api=coco)


In [22]:
BATCH_SIZE = 2

# Only the training split is shuffled; validation keeps a fixed order.
# sam_collate_fn keeps variable-length point prompts as per-sample lists.
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=sam_collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=sam_collate_fn)

In [23]:
# Sanity-check one training batch: stacked tensor shapes and list lengths.
first_batch = next(iter(train_dataloader))
print(
    f"Image batch shape:{first_batch['image'].shape}",
    f"Mask batch shape:{first_batch['mask'].shape}",
    f"Points batch length:{len(first_batch['points'])}",
    f"Point labels batch length:{len(first_batch['point_labels'])}",
    sep="\n",
)

Image batch shape:torch.Size([2, 3, 1024, 1024])
Mask batch shape:torch.Size([2, 1, 1024, 1024])
Points batch length:2
Point labels batch length:2
