In [None]:
# Cell 1: Import libraries
import os
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as T
import random
import re

In [None]:
# Cell 2: Settings
# Point the two folders below at your data before running anything else.
INPUT_FOLDER = "path/to/your/images"  # CHANGE THIS
OUTPUT_FOLDER = "output_balanced"     # CHANGE THIS

IMAGE_SIZE = 224    # side length (pixels) of the square images produced in Cell 8
TARGET_COUNT = 30   # Target images per class
SEED = 42

# Seed Python's and NumPy's global RNGs so random choices repeat between runs.
for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(SEED)

In [None]:
# Cell 3: Create folders
# One sub-folder per pipeline stage; exist_ok makes re-running this cell harmless.
for _stage_dir in ("augmented", "resized", "final"):
    os.makedirs(f"{OUTPUT_FOLDER}/{_stage_dir}", exist_ok=True)

In [None]:
# Cell 4: Basic functions
def get_class_label(filename):
    """Return the class label encoded in ``filename``.

    The label is the leftmost run of digits immediately followed by 'p' or 'P'
    (e.g. 'scan_5p.png' -> '5p', 'X_100P_aug3.png' -> '100p'); the returned
    suffix is always a lowercase 'p'. Filenames without such a run map to
    'unknown'.
    """
    found = re.search(r'(\d+)[pP]', filename)
    return "unknown" if found is None else found.group(1) + "p"

def make_square(image, size):
    """Resize ``image`` to fit a ``size`` x ``size`` square and pad the remainder.

    The aspect ratio is preserved: the longer side is scaled to ``size`` and the
    scaled image is pasted in the centre of a square canvas filled with the
    mean pixel value of the scaled image.

    Args:
        image: PIL image. Non-'L' images are converted to grayscale first, so
            the fill value and the output are always single-channel.
        size: side length of the output square, in pixels.

    Returns:
        A new ``size`` x ``size`` PIL image in 'L' mode.

    Raises:
        ValueError: if ``image`` has zero width or height.
    """
    if image.mode != 'L':
        image = image.convert('L')

    w, h = image.size
    if w == 0 or h == 0:
        raise ValueError(f"Cannot make_square an empty image of size {image.size}")

    # round() rather than int(): int() truncates a product that lands one ulp
    # below `size` (e.g. 223.99999999999997) to 223, leaving a stray 1-px
    # padding strip on the side that should fill the square. max(1, ...) keeps
    # very thin images from collapsing to a zero-pixel dimension, which
    # resize() rejects.
    scale = size / max(w, h)
    new_w = max(1, round(w * scale))
    new_h = max(1, round(h * scale))

    resized = image.resize((new_w, new_h), Image.BILINEAR)

    # Pad with the mean intensity so the border roughly matches the image's tone
    bg_color = int(np.mean(np.array(resized)))
    square = Image.new('L', (size, size), bg_color)

    # Paste in centre
    square.paste(resized, ((size - new_w) // 2, (size - new_h) // 2))
    return square

In [None]:
# Cell 5: Define augmentations
# Each op takes a PIL image and returns a new one. PIL's rotate() turns
# counter-clockwise, and expand=True grows the canvas so no corner is cropped.
def aug1(img):
    """Rotate by 90 degrees."""
    return img.rotate(angle=90, expand=True)


def aug2(img):
    """Rotate by 180 degrees."""
    return img.rotate(angle=180, expand=True)


def aug3(img):
    """Rotate by 270 degrees."""
    return img.rotate(angle=270, expand=True)


def aug4(img):
    """Mirror left <-> right."""
    return img.transpose(method=Image.FLIP_LEFT_RIGHT)


def aug5(img):
    """Mirror top <-> bottom."""
    return img.transpose(method=Image.FLIP_TOP_BOTTOM)


# Pool the augmentation step samples from
augmentation_list = [aug1, aug2, aug3, aug4, aug5]

In [None]:
# Cell 6: Group images by class
# '.tif' is the more common TIFF extension; '.tiff' alone silently skips those files.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.tif', '.tiff', '.bmp')

# sorted(): os.listdir returns names in arbitrary, filesystem-dependent order.
# A fixed order is needed for the seeded random.choice() over these lists
# (Cell 7) to pick the same sources on every machine.
image_files = sorted(
    f for f in os.listdir(INPUT_FOLDER) if f.lower().endswith(IMAGE_EXTENSIONS)
)
assert image_files, f"No image files found in {INPUT_FOLDER!r}"

# Group by class: label -> list of filenames (in sorted order)
classes = {}
for filename in image_files:
    classes.setdefault(get_class_label(filename), []).append(filename)

# Show counts
print("Images per class:")
for label, images in sorted(classes.items()):
    print(f"  {label}: {len(images)} images")

In [None]:
# Cell 7: Copy originals and create augmentations
MAX_AUG_ATTEMPTS = 10  # redraws allowed when a random composition is a no-op


def _load_grayscale(path):
    """Read an image from disk as a single-channel ('L') PIL image.

    The file is opened in a ``with`` block so its handle is released right away;
    ``convert('L')`` returns a fully loaded in-memory image (a plain copy when
    the source is already 'L').
    """
    with Image.open(path) as src:
        return src.convert('L')


def _same_pixels(a, b):
    """True when two images have the same size, mode and raw pixel bytes."""
    return a.size == b.size and a.mode == b.mode and a.tobytes() == b.tobytes()


def _random_augmentation(img):
    """Apply 3-4 distinct ops from ``augmentation_list`` in random order.

    Every op is a rotation or flip, so some combinations cancel out
    (e.g. 180-degree rotation + left-right flip + top-bottom flip is the
    identity). Those would just write another copy of the original. Such draws
    are redrawn, up to MAX_AUG_ATTEMPTS times. If the image is so symmetric
    that every draw reproduces it, the last draw is kept.
    """
    for _ in range(MAX_AUG_ATTEMPTS):
        num_augs = random.randint(3, 4)
        augmented = img
        for aug_func in random.sample(augmentation_list, num_augs):
            augmented = aug_func(augmented)
        if not _same_pixels(augmented, img):
            break
    return augmented


print("\nProcessing images...")

# NOTE: files written by earlier runs are never deleted here. Clear
# f"{OUTPUT_FOLDER}/augmented" before re-running if inputs or TARGET_COUNT changed.
for class_label, filenames in sorted(classes.items()):
    print(f"\nClass {class_label}:")

    # Copy all originals first, as grayscale PNGs
    for filename in filenames:
        img = _load_grayscale(os.path.join(INPUT_FOLDER, filename))
        base = os.path.splitext(filename)[0]
        img.save(f"{OUTPUT_FOLDER}/augmented/{base}_original.png")

    # Top up classes below TARGET_COUNT with augmented copies of random originals
    current_count = len(filenames)
    if current_count < TARGET_COUNT:
        needed = TARGET_COUNT - current_count
        print(f"  Creating {needed} augmentations...")

        for i in range(needed):
            source_file = random.choice(filenames)
            img = _load_grayscale(os.path.join(INPUT_FOLDER, source_file))
            augmented = _random_augmentation(img)

            base = os.path.splitext(source_file)[0]
            augmented.save(f"{OUTPUT_FOLDER}/augmented/{base}_aug{i}.png")

In [None]:
# Cell 8: Resize all images
# Squares every PNG from the augmented stage to IMAGE_SIZE x IMAGE_SIZE via make_square.
print("\nResizing all images...")
augmented_files = [f for f in os.listdir(f"{OUTPUT_FOLDER}/augmented")
                   if f.endswith('.png')]

for i, filename in enumerate(augmented_files):
    if i % 50 == 0:
        print(f"  {i}/{len(augmented_files)}")

    # `with` releases the source file handle once the squared copy exists
    with Image.open(f"{OUTPUT_FOLDER}/augmented/{filename}") as img:
        resized = make_square(img, IMAGE_SIZE)
    resized.save(f"{OUTPUT_FOLDER}/resized/{filename}")

In [None]:
# Cell 9: Normalize all images
print("\nNormalizing images...")
resized_files = [f for f in os.listdir(f"{OUTPUT_FOLDER}/resized")
                 if f.endswith('.png')]

# Built once: the transform is identical for every image. Grayscale is
# replicated to 3 channels below and normalised with the ImageNet channel
# means/stds (presumably to match an ImageNet-pretrained model).
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

for i, filename in enumerate(resized_files):
    if i % 50 == 0:
        print(f"  {i}/{len(resized_files)}")

    with Image.open(f"{OUTPUT_FOLDER}/resized/{filename}") as img:
        # Load and convert to RGB
        img_rgb = img.convert('RGB')
        tensor = transform(img_rgb)

        # Save tensor; splitext swaps only the final extension, unlike
        # str.replace, which would also rewrite any '.png' inside the stem
        stem = os.path.splitext(filename)[0]
        np.save(f"{OUTPUT_FOLDER}/final/{stem}.npy", tensor.numpy())

        # Save image for viewing
        img.save(f"{OUTPUT_FOLDER}/final/{filename}")

In [None]:
# Cell 10: Check final counts
# Tally the viewing PNGs in the final folder by the label parsed from each name.
print("\nFinal image count per class:")
final_dir = f"{OUTPUT_FOLDER}/final"
png_names = (name for name in os.listdir(final_dir) if name.endswith('.png'))

final_counts = {}
for png_name in png_names:
    class_label = get_class_label(png_name)
    final_counts[class_label] = final_counts.get(class_label, 0) + 1

for class_label, n_images in sorted(final_counts.items()):
    print(f"  {class_label}: {n_images} images")

In [None]:
# Cell 11: Simple data loader
def load_data(folder=None):
    """Load every processed tensor from the final output folder.

    Args:
        folder: directory holding the ``.npy`` files written by Cell 9.
            Defaults to ``f"{OUTPUT_FOLDER}/final"``.

    Returns:
        ``(data, labels)``: ``data`` is a list of torch tensors, one per
        ``.npy`` file. ``labels`` holds the matching class labels from
        ``get_class_label``. Files are read in sorted filename order, so the
        ordering is reproducible across filesystems (``os.listdir`` order is
        arbitrary).
    """
    if folder is None:
        folder = f"{OUTPUT_FOLDER}/final"

    data = []
    labels = []
    for filename in sorted(os.listdir(folder)):
        if not filename.endswith('.npy'):
            continue
        # np.load defaults to allow_pickle=False, so only plain arrays are accepted
        tensor = np.load(os.path.join(folder, filename))
        data.append(torch.from_numpy(tensor))
        labels.append(get_class_label(filename))

    return data, labels

# Usage
print("\nTo load data:")
print("data, labels = load_data()")
n_tensors = sum(1 for f in os.listdir(f'{OUTPUT_FOLDER}/final') if f.endswith('.npy'))
print(f"Total tensors ready: {n_tensors}")