In [4]:
import os
import random
import numpy as np
from PIL import Image
from torchvision import transforms

In [5]:
# Map each class name to its integer class id (ids follow the order below).
# NOTE(review): presumably these ids match the first column of the label
# .txt files parsed by load_images_and_labels -- confirm.
CLASS_INDEX = {
    name: idx
    for idx, name in enumerate((
        'hat_on', 'hat_off',
        'clothes_on', 'clothes_off',
        'shoes_on', 'shoes_off',
        'mask_on', 'mask_off',
    ))
}

In [6]:
def rand_bbox(size, lam):
    """Sample a random rectangular cut region for CutMix.

    Args:
        size: shape indexed as (C, H, W); only ``size[1]`` (height) and
            ``size[2]`` (width) are read.
        lam: mixing ratio. Both side lengths are scaled by ``sqrt(1 - lam)``,
            so the unclipped rectangle covers about ``(1 - lam) * H * W`` pixels.

    Returns:
        ``(x1, y1, x2, y2)`` pixel bounds clipped to the image, usable as the
        slice ``[y1:y2, x1:x2]``.
    """
    height, width = size[1], size[2]
    scale = np.sqrt(1. - lam)
    half_w = int(width * scale) // 2
    half_h = int(height * scale) // 2

    # Centre drawn from the global NumPy RNG: x first, then y.
    centre_x = np.random.randint(width)
    centre_y = np.random.randint(height)

    x1, x2 = (np.clip(centre_x + delta, 0, width) for delta in (-half_w, half_w))
    y1, y2 = (np.clip(centre_y + delta, 0, height) for delta in (-half_h, half_h))
    return x1, y1, x2, y2

In [7]:
def update_labels(labels_a, labels_b, bbx1, bby1, bbx2, bby2, image_size):
    """Merge the labels of two images after a CutMix paste.

    The mixed image is image A with the pixel rectangle
    ``[bby1:bby2, bbx1:bbx2]`` replaced by the *same* rectangle of image B
    (same position, no translation).

    Each label is ``[class_index, x_center, y_center, width, height]`` where
    the four geometry values are fractions of the image width/height.

    Args:
        labels_a: labels of the base image A.
        labels_b: labels of the pasted image B (assumed to refer to an image of
            the same size as A).
        bbx1, bby1, bbx2, bby2: pixel bounds of the pasted rectangle.
        image_size: image shape indexed as (C, H, W).

    Returns:
        A new list containing
        * A labels whose centre lies outside the pasted rectangle, unchanged;
        * B labels whose centre lies inside the pasted rectangle, with the box
          clipped to the rectangle (the rest of a B object is not visible).
        If the rectangle is empty, nothing from B was pasted: every A label is
        kept and every B label is dropped.
    """
    img_h, img_w = image_size[1], image_size[2]
    region_empty = bbx2 <= bbx1 or bby2 <= bby1

    def centre_covered(x_center, y_center):
        # Compare the box centre, in pixels, against the pasted rectangle.
        if region_empty:
            return False
        px, py = x_center * img_w, y_center * img_h
        return bbx1 <= px <= bbx2 and bby1 <= py <= bby2

    new_labels = []

    for class_index, x_center, y_center, width, height in labels_a:
        if centre_covered(x_center, y_center):
            continue  # Skip labels covered by the pasted patch
        new_labels.append([class_index, x_center, y_center, width, height])

    for class_index, x_center, y_center, width, height in labels_b:
        if not centre_covered(x_center, y_center):
            continue  # Skip labels whose centre was not pasted
        # The patch keeps its original position, so B boxes need no shift --
        # only clipping to the pasted rectangle (in pixel units).
        x1 = max((x_center - width / 2) * img_w, bbx1)
        x2 = min((x_center + width / 2) * img_w, bbx2)
        y1 = max((y_center - height / 2) * img_h, bby1)
        y2 = min((y_center + height / 2) * img_h, bby2)
        if x2 <= x1 or y2 <= y1:
            continue  # degenerate after clipping
        new_labels.append([
            class_index,
            float((x1 + x2) / 2 / img_w),
            float((y1 + y2) / 2 / img_h),
            float((x2 - x1) / img_w),
            float((y2 - y1) / img_h),
        ])

    return new_labels

In [8]:
# Function to implement CutMix data augmentation
def cutmix_data(image_paths, labels_dict, output_folder, count, alpha=1.0):
    """Generate ``count`` CutMix images and their label files.

    For each sample two distinct images A and B are drawn from
    ``image_paths``. A random rectangle (from ``rand_bbox``) of A is replaced
    by the same rectangle of B, and the labels are merged with
    ``update_labels``. Results are written to ``output_folder`` as
    ``<stem of A>_<i>.jpg`` plus a matching ``<stem of A>_<i>.txt``.

    Args:
        image_paths: list of image file paths (at least two).
        labels_dict: maps each path in ``image_paths`` to its list of labels
            ``[class_index, x_center, y_center, width, height]``.
        output_folder: directory to write results into; created if missing.
        count: number of mixed samples to generate.
        alpha: both shape parameters of the Beta distribution ``lam`` is
            drawn from.

    Draws from the global ``random`` and ``np.random`` generators; seed them
    for reproducible output.
    """
    os.makedirs(output_folder, exist_ok=True)

    for i in range(count):
        # Randomly select two images
        image_path_a, image_path_b = random.sample(image_paths, 2)

        # Context managers close the files; convert() returns an in-memory copy.
        with Image.open(image_path_a) as src_a:
            img_a = src_a.convert('RGB')
        with Image.open(image_path_b) as src_b:
            img_b = src_b.convert('RGB')
        width, height = img_a.size

        # Labels are size-relative, so resizing B to A's size keeps B's labels
        # valid and makes the pasted rectangles line up.
        if img_b.size != img_a.size:
            img_b = img_b.resize(img_a.size)

        # Generate a random value from Beta distribution
        lam = np.random.beta(alpha, alpha)
        image_size = (3, height, width)  # (C, H, W), as rand_bbox/update_labels index it
        bbx1, bby1, bbx2, bby2 = rand_bbox(image_size, lam)

        # Paste in PIL: same region as tensor[:, y1:y2, x1:x2], without a
        # possibly lossy float round-trip through ToTensor/ToPILImage.
        mixed_img = img_a.copy()
        box = tuple(int(v) for v in (bbx1, bby1, bbx2, bby2))
        if box[2] > box[0] and box[3] > box[1]:
            mixed_img.paste(img_b.crop(box), box[:2])

        mixed_labels = update_labels(labels_dict[image_path_a], labels_dict[image_path_b],
                                     bbx1, bby1, bbx2, bby2, image_size)

        # Save the mixed image and its labels under the same stem
        stem = os.path.splitext(os.path.basename(image_path_a))[0]
        output_image_path = os.path.join(output_folder, f"{stem}_{i}.jpg")
        mixed_img.save(output_image_path)

        mixed_label_path = os.path.join(output_folder, f"{stem}_{i}.txt")
        with open(mixed_label_path, 'w') as f:
            for label in mixed_labels:
                f.write(' '.join(str(elem) for elem in label) + '\n')

        print(f"Processed and saved: {output_image_path}")

In [12]:
# Function to load images and labels
def load_images_and_labels(folder):
    """Collect ``.jpg`` images in ``folder`` and parse their label files.

    Each image ``<stem>.jpg`` (extension matched case-insensitively) must have
    a sibling ``<stem>.txt`` with one label per line:
    ``class_index x_center y_center width height`` (whitespace-separated; only
    the first five fields are read). Blank lines are ignored.

    Args:
        folder: directory to scan (non-recursive).

    Returns:
        ``(image_paths, labels_dict)``: ``image_paths`` sorted by file name,
        and ``labels_dict`` mapping each image path to a list of
        ``[int, float, float, float, float]`` labels.

    Raises:
        FileNotFoundError: if an image has no matching ``.txt`` file.
        ValueError: if a non-blank label line has fewer than five fields or a
            field cannot be parsed.
    """
    image_paths = []
    labels_dict = {}
    # os.listdir order is arbitrary; sorting makes downstream seeded sampling
    # reproducible.
    for file in sorted(os.listdir(folder)):
        stem, ext = os.path.splitext(file)
        if ext.lower() != '.jpg':
            continue
        image_path = os.path.join(folder, file)
        # Swap only the extension (str.replace would also rewrite '.jpg'
        # occurring elsewhere in the path).
        label_path = os.path.join(folder, stem + '.txt')
        labels = []
        with open(label_path, 'r') as f:
            for lineno, line in enumerate(f, start=1):
                parts = line.split()
                if not parts:
                    continue
                if len(parts) < 5:
                    raise ValueError(
                        f"{label_path}:{lineno}: expected 5 fields, got {len(parts)}")
                labels.append([int(parts[0]), float(parts[1]), float(parts[2]),
                               float(parts[3]), float(parts[4])])
        image_paths.append(image_path)
        labels_dict[image_path] = labels
    return image_paths, labels_dict

In [13]:
# set paths
input_folder = 'raw_dataset_2'
output_folder = "rawdata_cutmix"

# Total number of images to generate
total_images_to_generate = 100

# cutmix_data draws from both `random` (random.sample) and `np.random`
# (Beta sample, rand_bbox); seed both so a rerun over the same input file
# order produces the same mixes.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# Load images and labels
image_paths, labels_dict = load_images_and_labels(input_folder)

# Ensure there are at least two images for CutMix
if len(image_paths) < 2:
    raise ValueError("At least two images are required for CutMix.")

# Generate CutMix images
cutmix_data(image_paths, labels_dict, output_folder, total_images_to_generate)

Processed and saved: rawdata_cutmix\frame_90420_0.jpg
Processed and saved: rawdata_cutmix\frame_6420_1.jpg
Processed and saved: rawdata_cutmix\frame_91980_2.jpg
Processed and saved: rawdata_cutmix\frame_55680_3.jpg
Processed and saved: rawdata_cutmix\frame_115020_4.jpg
Processed and saved: rawdata_cutmix\frame_89040_5.jpg
Processed and saved: rawdata_cutmix\frame_82800_6.jpg
Processed and saved: rawdata_cutmix\frame_127440_7.jpg
Processed and saved: rawdata_cutmix\frame_82500_8.jpg
Processed and saved: rawdata_cutmix\frame_84900_9.jpg
Processed and saved: rawdata_cutmix\frame_56460_10.jpg
Processed and saved: rawdata_cutmix\frame_86220_11.jpg
Processed and saved: rawdata_cutmix\frame_59460_12.jpg
Processed and saved: rawdata_cutmix\frame_123300_13.jpg
Processed and saved: rawdata_cutmix\frame_1020_14.jpg
Processed and saved: rawdata_cutmix\frame_159900_15.jpg
Processed and saved: rawdata_cutmix\frame_96540_16.jpg
Processed and saved: rawdata_cutmix\frame_42060_17.jpg
Processed and save