In [1]:
import warnings

import numpy as np

import utils

In [2]:
# Load the preprocessed train/test images and labels through the project
# helper; the four names below are reused by every later cell.
train_images, train_labels, test_images, test_labels = utils.load_processed_data('preprocessed_data.pkl')

In [3]:
def augment_data_for_balance(images, labels, target_counts=None, seed=0):
    """
    Balance a labelled image set by downsampling large classes and adding
    horizontally flipped copies to small ones.

    Each class is compared against its target count:
      - more samples than the target: a random subset of exactly `target`
        originals is kept (sampled without replacement). This applies to
        every form of `target_counts`, not only to an integer;
      - fewer samples than the target: every original is kept and randomly
        chosen originals are flipped left-right and appended. Each original
        is flipped at most once, so a class can at most double in size. If
        that is not enough to reach the target, a RuntimeWarning is issued
        and the class stays below the target;
      - exactly the target: the class is kept unchanged.

    The combined result is shuffled before it is returned. All randomness
    (sample selection and the final shuffle) comes from a generator that is
    local to this call, so the output is reproducible for a given `seed` and
    NumPy's global random state is left untouched.

    Parameters:
    images: Array of images (n_samples, height, width, channels)
    labels: Array of labels (n_samples,)
    target_counts: Either:
                  - Dict mapping class labels to target counts, or
                  - Number (Python or NumPy scalar) specifying the same
                    target count for all classes; truncated with int()
                  If None, will use the max count of any class for all classes
    seed: Seed for the local random generator, or an existing
          np.random.Generator to draw from. Defaults to 0; pass None for a
          different result on every call.

    Returns:
    tuple: (augmented_images, augmented_labels)

    Raises:
    ValueError: if images and labels differ in length, or a target count
                is negative.
    KeyError: if target_counts is a dict without an entry for a class that
              occurs in labels.
    """
    images = np.asarray(images)
    labels = np.asarray(labels)
    if len(images) != len(labels):
        raise ValueError(
            f"images and labels must have the same length, "
            f"got {len(images)} and {len(labels)}"
        )

    # Count examples per class
    unique_labels, counts = np.unique(labels, return_counts=True)
    if unique_labels.size == 0:
        # Nothing to balance; hand back empty copies with shape/dtype intact.
        return images.copy(), labels.copy()

    # Determine target counts
    if target_counts is None:
        # Use the max count as default for all classes
        max_count = int(counts.max())
        target_counts = {label: max_count for label in unique_labels}
    elif isinstance(target_counts, (int, float, np.integer, np.floating)):
        # A single number (including NumPy scalars) applies to every class
        target_count_value = int(target_counts)
        target_counts = {label: target_count_value for label in unique_labels}

    # One local generator drives both the sampling and the shuffle, so the
    # whole result depends only on `seed` and never on global RNG state.
    rng = np.random.default_rng(seed)

    image_chunks = []
    label_chunks = []

    # Process each class
    for label in unique_labels:
        class_indices = np.where(labels == label)[0]
        current_count = len(class_indices)
        target_count = int(target_counts[label])
        if target_count < 0:
            raise ValueError(
                f"target count for class {label} must be non-negative, "
                f"got {target_count}"
            )

        if current_count > target_count:
            # Need to downsample
            selected_indices = rng.choice(class_indices, size=target_count, replace=False)
            image_chunks.append(images[selected_indices])
            class_size = target_count
        else:
            # Keep every original (this also covers the "already at target" case)
            image_chunks.append(images[class_indices])
            class_size = current_count

            # Each original may be flipped at most once, so the number of
            # synthetic samples is capped by the number of originals.
            n_to_augment = min(target_count - current_count, current_count)
            if n_to_augment > 0:
                indices_to_augment = rng.choice(class_indices, size=n_to_augment, replace=False)
                # Axis 2 of the batch is the width axis of each image, i.e.
                # the same flip np.fliplr applies to a single image.
                image_chunks.append(np.flip(images[indices_to_augment], axis=2))
                class_size += n_to_augment

            if class_size < target_count:
                warnings.warn(
                    f"Class {label} has only {current_count} original samples; "
                    f"flipping each once yields {class_size}, below the target "
                    f"of {target_count}.",
                    RuntimeWarning,
                    stacklevel=2,
                )

        label_chunks.append(np.full(class_size, label, dtype=unique_labels.dtype))

    augmented_images = np.concatenate(image_chunks, axis=0)
    augmented_labels = np.concatenate(label_chunks)

    # Shuffle the augmented dataset
    shuffled_indices = rng.permutation(len(augmented_labels))
    return augmented_images[shuffled_indices], augmented_labels[shuffled_indices]


def print_class_distribution(labels, class_names=None):
    """
    Report how many samples each class has and its share of the dataset.

    One line is printed per distinct label, in the sorted order returned by
    np.unique, after a "Class distribution:" header.

    Parameters:
    labels: Array of labels
    class_names: Optional mapping from class ID to a display name; IDs
                 without an entry are shown as "Class <id>"
    """
    class_ids, class_sizes = np.unique(labels, return_counts=True)
    n_total = len(labels)

    print("Class distribution:")
    for class_id, size in zip(class_ids, class_sizes):
        # Use the caller-supplied name when there is one for this ID.
        has_name = class_names is not None and class_id in class_names
        display_name = class_names[class_id] if has_name else f"Class {class_id}"

        share = (size / n_total) * 100
        print(f"{display_name}: {size} examples ({share:.1f}%)")

In [4]:
# Rebalance the training split only (the test split is not passed in):
# every class gets the same target of 560 samples.
train_images_aug, train_labels_aug = augment_data_for_balance(train_images, train_labels, target_counts=560)

In [5]:
# Sanity check: per-class counts and percentages of the rebalanced training labels.
print_class_distribution(train_labels_aug)

Class distribution:
Class 1: 560 examples (14.3%)
Class 2: 560 examples (14.3%)
Class 3: 560 examples (14.3%)
Class 4: 560 examples (14.3%)
Class 5: 560 examples (14.3%)
Class 6: 560 examples (14.3%)
Class 7: 560 examples (14.3%)


In [6]:
# Save the augmented training split together with the original, un-augmented
# test split under a new file name so the preprocessed input file is kept.
utils.save_processed_data('augmented_preprocessed_data.pkl', train_images_aug, train_labels_aug, test_images, test_labels)

Data saved to ./281_final_project_data/augmented_preprocessed_data.pkl
