# Dataset setup

In [1]:
# Libraries


import os
import random
import numpy as np
from PIL import Image, ImageOps, ImageEnhance
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split


In [2]:
# === CONFIG ===
random.seed(42)
np.random.seed(42)
processed_dir = "c:/Users/user/Documents/Real_time_weapon_detection/processed_data"
img_size = (224, 224)  # MobileNet expected size
base_augment_factor = 2

# Threat mapping
threat_map = {
    "human_with_weapon": "THREAT",
    "weapon_only": "THREAT",
    "human_only": "NO_THREAT",
    "no_threat": "NO_THREAT"
}

# === AUGMENTATION FUNCTION (RGB) ===
def augment_image(img):
    """Apply random augmentations to an RGB image."""
    if random.random() > 0.5:
        img = ImageOps.mirror(img)  # Flip

    angle = random.uniform(-15, 15)
    img = img.rotate(angle)

    # Brightness & contrast adjustments
    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(random.uniform(0.8, 1.2))

    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(random.uniform(0.8, 1.2))

    return img

# === STEP 1: COUNT FILES PER CLASS ===
class_counts = {"THREAT": 0, "NO_THREAT": 0}
class_files = {"THREAT": [], "NO_THREAT": []}

for label in os.listdir(processed_dir):
    class_dir = os.path.join(processed_dir, label)
    if not os.path.isdir(class_dir):
        continue
    mapped_label = threat_map[label]
    files = [f for f in os.listdir(class_dir) if f.lower().endswith((".jpg", ".jpeg", ".png"))]
    class_counts[mapped_label] += len(files)
    class_files[mapped_label].extend([(os.path.join(class_dir, f), mapped_label) for f in files])

print(f"[i] Class counts before augmentation: {class_counts}")

# === STEP 2: BALANCE DATASET WITH AUGMENTATION ===
max_count = max(class_counts.values())
images, labels = [], []

for mapped_label, files in class_files.items():
    num_original = len(files)
    if num_original < max_count:
        extra_needed = max_count - num_original
        augment_factor = max(base_augment_factor, int(np.ceil(extra_needed / num_original)))
    else:
        augment_factor = base_augment_factor

    print(f"[i] Processing class '{mapped_label}' with augment_factor={augment_factor}")

    for img_path, lbl in files:
        try:
            img = Image.open(img_path).convert("RGB").resize(img_size)

            # Add original (keep 0–255 values for MobileNet preprocessing later)
            images.append(np.array(img, dtype=np.float32))
            labels.append(lbl)

            # Augment
            for _ in range(augment_factor):
                aug_img = augment_image(img)
                images.append(np.array(aug_img, dtype=np.float32))
                labels.append(lbl)

        except Exception as e:
            print(f"[✗] Could not process {img_path}: {e}")

# === STEP 3: ENCODE LABELS ===
le = LabelEncoder()
labels_encoded = le.fit_transform(labels)  # THREAT=1, NO_THREAT=0 (depending on fit order)

# === STEP 4: TO NUMPY ARRAYS ===
X = np.array(images)  # Shape: (N, 224, 224, 3)
y = np.array(labels_encoded)

print(f"[✓] After augmentation, class distribution: {dict(zip(le.classes_, np.bincount(y)))}")

# === STEP 5: SPLIT ===
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, stratify=y, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp, random_state=42)

# === STEP 6: SAVE TO DISK ===
np.savez(
    "dataset_rgb_224_binary_balanced.npz",
    X_train=X_train, y_train=y_train,
    X_val=X_val, y_val=y_val,
    X_test=X_test, y_test=y_test,
    label_names=le.classes_
)

print(f"[✓] Balanced MobileNet-ready dataset saved.")
print(f"[✓] Train: {X_train.shape}, Val: {X_val.shape}, Test: {X_test.shape}")


[i] Class counts before augmentation: {'THREAT': 698, 'NO_THREAT': 493}
[i] Processing class 'THREAT' with augment_factor=2
[i] Processing class 'NO_THREAT' with augment_factor=2
[✓] After augmentation, class distribution: {'NO_THREAT': 1479, 'THREAT': 2094}
[✓] Balanced MobileNet-ready dataset saved.
[✓] Train: (2501, 224, 224, 3), Val: (536, 224, 224, 3), Test: (536, 224, 224, 3)
