In [19]:
import utils
import numpy as np
from sklearn.model_selection import train_test_split

In [20]:
# Load the train/test images and labels from 'preprocessed_data.pkl'.
train_images, train_labels, test_images, test_labels = utils.load_processed_data(
    'preprocessed_data.pkl'
)

# Hold out 25% of the training samples as a validation set; the fixed
# random_state makes the shuffled split repeatable.
# NOTE(review): the split is not stratified by label - confirm that is intended.
train_images, val_images, train_labels, val_labels = train_test_split(
    train_images,
    train_labels,
    test_size=0.25,
    train_size=0.75,
    random_state=1234,
    shuffle=True,
)

In [21]:
def augment_data_for_balance(images, labels, target_counts=None, seed=0):
    """
    Reduce class imbalance by adding horizontally flipped copies of under-represented classes.

    Every original sample is kept. For each class whose count is below its target,
    randomly chosen originals are flipped left-right and added. Each original is
    flipped at most once, so a class can at most double in size and may still end
    up below its target. Classes at or above their target are left unchanged.

    Parameters:
    images: np.ndarray, shape (n_samples, H, W, C)
    labels: np.ndarray, shape (n_samples,)
    target_counts: dict or int or None - target number of samples per class.
        None uses the size of the largest class for every class.
        A number (Python or NumPy scalar, floats are truncated) is used for every class.
        A dict maps label -> target; labels missing from the dict are left unchanged.
    seed: int or None - seed of the private random generator that picks the images
        to flip and shuffles the result. None gives a different result on every call.

    Returns:
    (augmented_images, augmented_labels): tuple of np.ndarrays, shuffled together

    Raises:
    ValueError: if images and labels do not contain the same number of samples.
    """
    images = np.asarray(images)
    labels = np.asarray(labels)
    if len(images) != len(labels):
        raise ValueError(
            f"images and labels must have the same length, "
            f"got {len(images)} and {len(labels)}"
        )

    # A private generator makes both the flip selection and the final shuffle
    # reproducible, and avoids reseeding NumPy's global random state.
    rng = np.random.default_rng(seed)

    # Count existing examples per class
    unique_labels, counts = np.unique(labels, return_counts=True)

    # Normalise target_counts into a label -> target lookup
    if target_counts is None:
        fill = int(counts.max()) if counts.size else 0
        targets = {label: fill for label in unique_labels}
    elif isinstance(target_counts, (int, float, np.integer, np.floating)):
        fill = int(target_counts)
        targets = {label: fill for label in unique_labels}
    else:
        targets = target_counts

    # Work with indices into `images` plus a "flip me" flag per output sample,
    # so the pixel data is copied only once at the end.
    source_parts = []
    flip_parts = []

    for label in unique_labels:
        class_indices = np.where(labels == label)[0]
        current_count = len(class_indices)

        try:
            target_count = int(targets[label])
        except KeyError:
            # No target given for this class: leave it unchanged
            target_count = current_count

        # Always keep all original images
        source_parts.append(class_indices)
        flip_parts.append(np.zeros(current_count, dtype=bool))

        # Cap at the number of originals so no image is flipped more than once
        n_to_augment = min(target_count - current_count, current_count)
        if n_to_augment > 0:
            chosen = rng.choice(class_indices, size=n_to_augment, replace=False)
            source_parts.append(chosen)
            flip_parts.append(np.ones(n_to_augment, dtype=bool))

    if source_parts:
        source_idx = np.concatenate(source_parts)
        is_flipped = np.concatenate(flip_parts)
    else:
        # Empty input: keep well-typed empty index arrays
        source_idx = np.zeros(0, dtype=np.intp)
        is_flipped = np.zeros(0, dtype=bool)

    # Shuffle originals and flipped copies together
    order = rng.permutation(source_idx.size)
    source_idx = source_idx[order]
    is_flipped = is_flipped[order]

    # Fancy indexing copies, so the caller's array is never modified
    augmented_images = images[source_idx]
    if is_flipped.any():
        # Axis 2 of the batch is the width axis of each image (same as np.fliplr per image)
        augmented_images[is_flipped] = np.flip(augmented_images[is_flipped], axis=2)

    return augmented_images, labels[source_idx]




def print_class_distribution(labels, class_names=None):
    """
    Print how many samples each class has and its share of the dataset.

    Parameters:
    labels: Array of labels
    class_names: Optional dictionary mapping class IDs to names; classes
        without an entry are shown as "Class <id>"
    """
    classes, frequencies = np.unique(labels, return_counts=True)
    n_samples = len(labels)

    print("Class distribution:")
    for class_id, frequency in zip(classes, frequencies):
        # Fall back to a generic name when no display name is available
        if class_names is None or class_id not in class_names:
            shown_name = f"Class {class_id}"
        else:
            shown_name = class_names[class_id]

        share = (frequency / n_samples) * 100
        print(f"{shown_name}: {frequency} examples ({share:.1f}%)")

In [22]:
# Add flipped copies to classes with fewer than 2000 training samples;
# classes already above that count are kept as they are.
train_images_aug, train_labels_aug = augment_data_for_balance(
    train_images,
    train_labels,
    target_counts=2000,
)

In [23]:
# Check the per-class counts of the augmented training labels.
print_class_distribution(labels=train_labels_aug)

Class distribution:
Class 1: 1878 examples (15.6%)
Class 2: 416 examples (3.5%)
Class 3: 1070 examples (8.9%)
Class 4: 3624 examples (30.2%)
Class 5: 2000 examples (16.7%)
Class 6: 1024 examples (8.5%)
Class 7: 2000 examples (16.7%)


In [24]:
# Save the augmented training set together with the validation and test sets
# (only the training data was augmented).
utils.save_processed_data(
    'all_augmented_preprocessed_data.pkl',
    train_images_aug,
    train_labels_aug,
    val_images,
    val_labels,
    test_images,
    test_labels,
)

Data saved to ./281_final_project_data/all_augmented_preprocessed_data.pkl
