In [1]:
import json
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt
import uuid

In [2]:
image_path = '/Users/gauravneupane/Documents/ml/data/datasets/validation/image'
annotation_path = '/Users/gauravneupane/Documents/ml/data/datasets/validation/annos'

In [3]:
exclude_keys = ['source','pair_id']


In [4]:
def load_json(json_path):
    with open(json_path, 'r') as file:
        data = json.load(file)
    return data

In [5]:
def build_output_image(image_shape,json_data):
    '''extract mask information as points from the dataset and create a new image as a mask'''
    image_segmentation_data_points = []
    filtered_keys = [key for key in json_data.keys() if key not in exclude_keys]
    categories = []
    for key in filtered_keys:
        data_points = json_data[key]['segmentation']
        category = json_data[key]['category_name']
        class_label = json_data[key]['category_id']
        segmentation_with_class = {'data_points': data_points,
                                   'category': category,
                                   'class_label': class_label}
        image_segmentation_data_points.append(segmentation_with_class)
        categories.append(category)
    # mask = create_mask_for_whole_image(image_shape,image_segmentation_data_points)
    return categories, image_segmentation_data_points

In [39]:
images_per_class = 100
output_dir = '/Users/gauravneupane/Documents/ml/projects/image_retrieval/new_classification'
output_image_shape = (224,224)
training_category_count = {'trousers': 0,
                    'short sleeve top': 0,
                    'long sleeve dress': 0,
                    'long sleeve top': 0,
                    'skirt': 0,
                    'shorts': 0,
                    'long sleeve outwear': 0,
                    'vest dress': 0,
                    'short sleeve dress': 0,
                    'vest': 0,
                    'sling dress': 0,
                    'short sleeve outwear': 0,
                    'sling': 0
                    }
validation_category_count = {'trousers': 0,
                    'short sleeve top': 0,
                    'long sleeve dress': 0,
                    'long sleeve top': 0,
                    'skirt': 0,
                    'shorts': 0,
                    'long sleeve outwear': 0,
                    'vest dress': 0,
                    'short sleeve dress': 0,
                    'vest': 0,
                    'sling dress': 0,
                    'short sleeve outwear': 0,
                    'sling': 0
                    }
train_images = os.listdir(image_path)[:10000]
# val_images = os.listdir(image_path)[100001:150001]

train_path = os.path.join(output_dir, 'test')
extract_save_image(train_images, train_path, training_category_count, 100)
# # print(training_category_count)
# validation_path = os.path.join(output_dir, 'validation')
# extract_save_image(val_images, validation_path, validation_category_count, 200)


In [37]:
def extract_save_image(images_path, save_dir, categories, max_data):
    stop = True
    for i, file in enumerate(images_path):
        full_path = os.path.join(image_path, file)
        annot_path = os.path.join(annotation_path, file.split('.')[0]+'.json')
        annot_data = load_json(annot_path)
        image = cv2.imread(full_path)
        cats, data_points = build_output_image(output_image_shape,annot_data)
        for cat in cats:
            if categories[cat] < max_data:
                categories[cat] +=1 
                create_multiple_images_from_segmentation(image, data_points, output_dir = save_dir)
        

In [9]:
def create_multiple_images_from_segmentation(original_image, segmentation_cats, output_size=(224, 224), output_dir='output'):
    '''Create separate images for each segmentation from the original image and save them individually'''
    
    for idx, data in enumerate(segmentation_cats):
        segmentation = data['data_points']
        category = data['category']
        save_path = os.path.join(output_dir, category)
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        # Create a blank mask with the same size as the original image
        mask = np.zeros(original_image.shape[:2], dtype=np.uint8)
        
        # Draw the current segmentation on the mask
        all_points = []
        for seg in segmentation:
            points = np.array(seg).reshape(-1, 2)
            all_points.extend(points)
            cv2.fillPoly(mask, [points.astype(np.int32)], 255)  # Use 255 to fill the mask
        
        # Convert all_points to a NumPy array
        all_points = np.array(all_points)
        
        if all_points.size == 0:
            continue  # Skip if no valid points
        
        # Determine the bounding box for the segmentation points
        min_x, min_y = np.min(all_points, axis=0)
        max_x, max_y = np.max(all_points, axis=0)
        
        # Extract the region of interest (ROI) using the mask
        roi = cv2.bitwise_and(original_image, original_image, mask=mask)
        
        # Crop the ROI to the bounding box
        cropped_roi = roi[int(min_y):int(max_y), int(min_x):int(max_x)]
        
        # Resize the cropped ROI to the desired output size
        if cropped_roi.size > 0:  # Check if the cropped ROI is not empty
            resized_image = cv2.resize(cropped_roi, output_size, interpolation=cv2.INTER_LINEAR)
            
            # Save the segmented image
            output_path = os.path.join(save_path, f'{uuid.uuid4()}.jpg')
            cv2.imwrite(output_path, resized_image)

In [50]:
training_data = '/Users/gauravneupane/Documents/ml/projects/image_retrieval/new_classification/training'
validation_data = '/Users/gauravneupane/Documents/ml/projects/image_retrieval/new_classification/validation'

for folder_name in os.listdir(validation_data):
    files_path = os.path.join(validation_data, folder_name)
    try:
        for image_file in os.listdir(files_path)[200:]:
            path = os.path.join(files_path, image_file)
            os.remove(path)
    except:
        print(f"error in {files_path}")
    # print(len(os.listdir(files[1000:])))
    # for file in files[1000:]:



error in /Users/gauravneupane/Documents/ml/projects/image_retrieval/new_classification/validation/.DS_Store
