In [4]:
import os
import shutil
import pandas as pd
import numpy as np
from PIL import Image
from tqdm import tqdm

class ImageDataset:
    """Index-based access to ``(image_array, target, name)`` triples from a CSV.

    Each CSV row must hold exactly two values — an image file name (relative
    to ``directory``) and its label — because ``__getitem__`` unpacks every
    row into those two values.
    """

    def __init__(self, csv, directory, target_size=(380, 380), mode=None):
        """
        Args:
            csv: path (or buffer) accepted by ``pd.read_csv``.
            directory: folder containing the image files named in the CSV.
            target_size: size passed to ``PIL.Image.resize`` (Pillow
                interprets it as ``(width, height)``).
            mode: optional PIL mode (e.g. ``"RGB"``) every image is converted
                to, so all returned arrays share one channel layout. ``None``
                (the default) keeps each file's native mode.
        """
        # Raw row array (the header row is consumed by read_csv); each row is
        # [file name, label].
        self.df = pd.read_csv(csv).values
        self.directory = directory
        self.target_size = target_size
        self.mode = mode

    def __len__(self):
        """Return the number of rows listed in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load, resize and return ``(image_array, target, name)`` for row *idx*.

        Raises:
            FileNotFoundError: if the image named in the CSV does not exist.
            IndexError: if *idx* is out of range (raised by the row array).
        """
        name, target = self.df[idx]
        img_path = os.path.join(self.directory, name)

        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image {name} not found in the directory {self.directory}")

        # Open in a context manager so the file handle is closed
        # deterministically; convert()/resize() return new, fully loaded
        # images that no longer depend on the open file.
        with Image.open(img_path) as src:
            img = src.convert(self.mode) if self.mode is not None else src
            resized = img.resize(self.target_size)
        return np.array(resized), target, name

class PatchAugmentation:
    """Paste a rectangular patch from a donor image onto a copy of a base image.

    The patch is cut from ``img2`` and pasted at the same (y, x) location in
    ``img1``. The result keeps ``img1``'s label; ``img2``'s label is ignored.

    Note: despite its name, ``max_patch_size`` is used as the *exact* patch
    size, clipped only so that the patch fits inside both images.
    """

    def __init__(self, max_patch_size=(32, 32)):
        # (height, width) of the pasted patch, in pixels.
        self.max_patch_size = max_patch_size

    def __call__(self, img1, label1, img2, label2):
        """Return ``(augmented_image, label1)``.

        Args:
            img1: base image array of shape (H, W) or (H, W, C); not modified.
            label1: label of the base image, returned unchanged.
            img2: donor image array; its spatial size may differ from img1's.
            label2: unused; accepted so callers can pass both labels.
        """
        # Restrict to the area both images share so the donor slice always
        # has the same spatial shape as the destination slice.
        h = min(img1.shape[0], img2.shape[0])
        w = min(img1.shape[1], img2.shape[1])
        ph, pw = min(self.max_patch_size[0], h), min(self.max_patch_size[1], w)

        # x is drawn before y, matching the original draw order so seeded
        # runs reproduce the same patch locations.
        x = np.random.randint(0, w - pw + 1)
        y = np.random.randint(0, h - ph + 1)

        # Work on a copy so the caller's array is not silently mutated.
        augmented = img1.copy()
        augmented[y:y + ph, x:x + pw] = img2[y:y + ph, x:x + pw]
        combined_label = label1  # Keep the label of the base image
        return augmented, combined_label

# Paths
csv_path = 'dj/data/train.csv'
image_dir = 'dj/data/train/'
output_dir = 'dj/data/train_augmented/'
output_csv_path = 'dj/data/train_augmented.csv'

# Create output directory if it doesn't exist
os.makedirs(output_dir, exist_ok=True)

# Load dataset
dataset = ImageDataset(csv_path, image_dir)
patch_augmenter = PatchAugmentation(max_patch_size=(32, 32))

# Rows of the new CSV file: [file name, label]
new_csv_data = []

# Copy the original images unchanged and record them in the new CSV.
# The CSV rows are iterated directly: a plain file copy does not need the
# decode + resize that dataset[i] performs.
# NOTE(review): originals keep their native resolution, whereas the patched
# images below are saved at dataset.target_size — confirm that downstream
# loading resizes both consistently.
for name, label in tqdm(dataset.df, desc="Copying originals"):
    shutil.copy(os.path.join(image_dir, name), os.path.join(output_dir, name))
    new_csv_data.append([name, label])

# Augmentation loop: one patched copy per original image.
# NOTE(review): no RNG seed is set here, so pairings and patch locations
# differ between runs; seed np.random beforehand if reproducibility matters.
n_images = len(dataset)
for i in tqdm(range(n_images), desc="Patch augmentation"):
    img1, label1, name1 = dataset[i]

    # Pick a donor image other than the base one (when another exists);
    # patching an image with itself would only produce a resized duplicate.
    if n_images > 1:
        idx2 = np.random.randint(0, n_images - 1)
        if idx2 >= i:
            idx2 += 1  # skip over i, keeping the draw uniform over the rest
    else:
        idx2 = i
    img2, label2, _ = dataset[idx2]

    # Apply patch augmentation (the base image's label is kept)
    augmented_img, augmented_label = patch_augmenter(img1, label1, img2, label2)

    # Convert back to a PIL image and save it next to the copied originals
    augmented_img_pil = Image.fromarray(augmented_img.astype(np.uint8))
    new_name = f"patch_{name1}"
    augmented_img_pil.save(os.path.join(output_dir, new_name))

    new_csv_data.append([new_name, augmented_label])

# Save new CSV file
new_csv_df = pd.DataFrame(new_csv_data, columns=['ID', 'target'])
new_csv_df.to_csv(output_csv_path, index=False)

print("Augmentation complete and CSV file saved.")


100%|██████████| 1570/1570 [00:07<00:00, 218.37it/s]
100%|██████████| 1570/1570 [00:14<00:00, 110.41it/s]

Augmentation complete and CSV file saved.



