In [2]:
import os
import warnings
from collections import Counter

import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.model_selection import StratifiedShuffleSplit
from torchvision import transforms
from tqdm import tqdm

# Suppress every warning for the rest of the session.
# NOTE(review): with no category/module argument this filter matches all
# warnings, including deprecation notices from the libraries imported above —
# consider narrowing it to the specific warning that was noisy.
warnings.filterwarnings("ignore")

In [3]:
# === 1. Path setup ===
# Defaults to the local kagglehub download; set MILITARY_AIRCRAFT_CROP_DIR to
# point the notebook at a copy of the dataset stored somewhere else.
path = os.environ.get(
    "MILITARY_AIRCRAFT_CROP_DIR",
    "C:/Users/jono_/.cache/kagglehub/datasets/a2015003713/militaryaircraftdetectiondataset/versions/86/crop",
)

# === 2. Collect image paths and labels ===
# Every sub-folder of `path` is treated as one class: the folder name is the
# label of each .jpg file found directly inside it.
images = []
labels = []

# os.listdir() returns entries in arbitrary order. Sort both levels so that
# positions in `images` / `labels` are stable across runs and machines; the
# index-based splits further down are only reproducible if this order is.
for folder in sorted(os.listdir(path)):
    folder_path = os.path.join(path, folder)
    if os.path.isdir(folder_path):
        for image in sorted(os.listdir(folder_path)):
            if image.endswith(".jpg"):
                images.append(os.path.join(folder_path, image))
                labels.append(folder)

# Fail early with a clear message instead of an opaque error from the splitter.
if not images:
    raise FileNotFoundError(f"No .jpg images found in class folders under {path}")

# Convert to arrays for indexing
images = np.array(images)
labels = np.array(labels)

# Save to CSV for reference
pd.DataFrame({"images": images, "labels": labels}).to_csv(
    "military_aircraft_crop.csv", index=False
)

# === 3. Define Transforms ===
# Per-channel mean / std handed to Normalize.
# NOTE(review): presumably computed from this dataset — confirm the source.
mean = [0.4913, 0.5240, 0.5560]
std = [0.1958, 0.1944, 0.1987]

# Training pipeline: fixed-size resize, random flip and colour jitter, then
# conversion to a normalized tensor.
train_steps = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
]
transform_train = transforms.Compose(train_steps)

# Evaluation pipeline: same resize and normalization, no random augmentation.
eval_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
]
transform_eval = transforms.Compose(eval_steps)

# === 4. Stratified Train/Val/Test Split ===
# Stage 1: keep 70% of the samples for training and hold out 30%, stratified
# by label.
sss_1 = StratifiedShuffleSplit(n_splits=1, test_size=0.30, random_state=42)
train_idx, temp_idx = list(sss_1.split(images, labels))[0]

# Stage 2: halve the held-out part into validation and test. The splitter
# yields positions *within* the held-out subset, so each is mapped back
# through `temp_idx` to obtain positions in the full `images` / `labels` arrays.
sss_2 = StratifiedShuffleSplit(n_splits=1, test_size=0.5, random_state=42)
val_idx, test_idx = (
    temp_idx[positions]
    for positions in next(sss_2.split(images[temp_idx], labels[temp_idx]))
)

# === 5. Sanity check: ensure classes are covered in all splits ===
def check_class_coverage(idxs, labels, name):
    """Print how many classes a split covers and which ones are under-represented.

    Args:
        idxs: Integer positions into ``labels`` that make up the split.
        labels: NumPy array holding the label of every sample in the dataset.
        name: Human-readable split name used in the printed messages.

    Prints the number of distinct classes present in the split out of all
    classes in ``labels``, followed by a warning that lists every class with
    fewer than 3 samples in the split. Classes that are absent from the split
    altogether (0 samples) are included in that warning.
    """
    all_classes = sorted(set(labels))
    class_counts = Counter(labels[idxs])
    print(f"\n{name} set class count: {len(class_counts)} / {len(all_classes)} classes")
    # Iterate over every known class, not just the keys of the Counter: a class
    # with no sample in this split has no Counter entry and would otherwise
    # never be reported. Counter returns 0 for missing keys.
    rare = [cls for cls in all_classes if class_counts[cls] < 3]
    if rare:
        print(f"⚠️ Rare classes with <3 samples in {name}: {rare}")

# Report class coverage for each of the three splits, in train/val/test order.
for split_name, split_idx in (
    ("Train", train_idx),
    ("Val", val_idx),
    ("Test", test_idx),
):
    check_class_coverage(split_idx, labels, split_name)

# === 6. Helper to apply transform and save ===
def transform_and_save(name, idxs, transform):
    """Transform the images of one split and save them as a single ``.pt`` file.

    Reads the module-level ``images`` (file paths) and ``labels`` arrays, loads
    each selected image as RGB, applies ``transform`` to it once, and writes
    ``military_aircraft_crop_{name}.pt`` to the current working directory. The
    file holds a dict with ``"tensors"`` (all transformed images stacked along
    a new first dimension) and ``"labels"`` (a list of ``str`` in the same
    order).

    Images that fail to load or transform are reported and skipped, so the
    saved split can contain fewer samples than ``idxs``.

    Args:
        name: Split name, used in the progress bar and the output file name.
        idxs: Iterable of integer positions into ``images`` / ``labels``.
        transform: Callable mapping a PIL image to a tensor. All outputs must
            share one shape so they can be stacked.

    Raises:
        RuntimeError: If no image could be processed, leaving nothing to save.
    """
    transformed_images = []
    transformed_labels = []

    for i in tqdm(idxs, desc=f"Processing {name}"):
        img_path = images[i]
        try:
            # The context manager releases the file handle even when decoding
            # or the transform raises; convert() returns an independent image.
            with Image.open(img_path) as raw_image:
                image_tensor = transform(raw_image.convert("RGB"))
        except Exception as e:
            # Deliberate best-effort: one unreadable file must not abort the split.
            print(f"❌ Skipping {img_path} due to error: {e}")
            continue
        transformed_images.append(image_tensor)
        # Indexing a NumPy string array yields np.str_; store a plain str so the
        # saved file contains only built-in types next to the tensor and can be
        # read back with torch.load(..., weights_only=True).
        transformed_labels.append(str(labels[i]))

    if not transformed_images:
        # torch.stack() rejects an empty list; fail with an actionable message.
        raise RuntimeError(
            f"No images could be processed for split '{name}'; nothing to save."
        )

    # Stack image tensors into one big tensor, e.g. [N, 3, 224, 224] with the
    # transforms defined in this notebook
    image_tensor_stack = torch.stack(transformed_images)

    # Save as dict
    torch.save(
        {"tensors": image_tensor_stack, "labels": transformed_labels},
        f"military_aircraft_crop_{name}.pt",
    )

# === 7. Process and save each split ===
# NOTE(review): transform_train contains random flip / colour-jitter steps and
# transform_and_save applies the transform exactly once per image, so the
# augmentation of the train split is sampled here and frozen into the saved
# tensors rather than re-drawn each epoch — confirm this is intended.
for split_name, split_idx, split_transform in (
    ("train", train_idx, transform_train),
    ("val", val_idx, transform_eval),
    ("test", test_idx, transform_eval),
):
    transform_and_save(split_name, split_idx, split_transform)

print("✅ All splits with stacked tensors saved successfully!")


Train set class count: 81 / 81 classes

Val set class count: 81 / 81 classes

Test set class count: 81 / 81 classes


Processing train: 100%|██████████| 24160/24160 [03:21<00:00, 120.05it/s]
Processing val: 100%|██████████| 5177/5177 [01:57<00:00, 44.07it/s] 
Processing test: 100%|██████████| 5178/5178 [01:48<00:00, 47.76it/s] 


✅ All splits with stacked tensors saved successfully!
