# In [None]:
import os, shutil, cv2, hashlib
from pathlib import Path
from PIL import Image
from tqdm import tqdm
import kagglehub
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.metrics import classification_report, confusion_matrix

# -----------------------------
# Paths setup
# Everything lives under ./data/leapgestrecog relative to the working directory.
PROJECT_ROOT = Path(".")
DATA_DIR = PROJECT_ROOT / "data" / "leapgestrecog"
RAW_DIR, CLEAN_DIR = DATA_DIR / "raw", DATA_DIR / "clean"
# Only RAW_DIR is created up front; CLEAN_DIR is left absent so the copy step
# below can tell a first run (folder missing) from a repeat run.
RAW_DIR.mkdir(parents=True, exist_ok=True)

# Shared settings used by the cleaning, data-loading and model code below.
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
VAL_SPLIT = 0.2
SEED = 42

# -----------------------------
# Download dataset
print("Downloading dataset...")
path = kagglehub.dataset_download("gti-upm/leapgestrecog")
print("Downloaded to:", path)

# Prefer the "leapGestRecog" sub-folder when the download contains one;
# otherwise use the download location itself.
nested_src = os.path.join(path, "leapGestRecog")
src = nested_src if os.path.isdir(nested_src) else path

print("Using source:", src)

# Copy dataset
# The raw copy is skipped as soon as RAW_DIR contains anything at all.
if any(RAW_DIR.iterdir()):
    print("Raw data already exists")
else:
    shutil.copytree(src, RAW_DIR, dirs_exist_ok=True)
    print("Raw data copied")

# The working ("clean") copy is created once from the raw copy; the cleaning
# functions later modify it in place, leaving RAW_DIR untouched.
if CLEAN_DIR.exists():
    print("Clean data already exists")
else:
    shutil.copytree(RAW_DIR, CLEAN_DIR, dirs_exist_ok=True)
    print("Clean data folder created")

# -----------------------------
# Cleaning functions
def remove_corrupted(folder: Path):
    """Delete files that OpenCV cannot decode from each class sub-folder.

    Only two levels are inspected: the immediate sub-directories of
    ``folder`` and the regular files directly inside them. Nested
    directories are left untouched.

    Args:
        folder: Root directory whose sub-directories hold the images.
    """
    print("\nRemoving corrupted images...")
    for cls in tqdm(sorted(os.listdir(folder))):
        cls_path = folder / cls
        if not cls_path.is_dir():
            continue
        for file in os.listdir(cls_path):
            img_path = cls_path / file
            # Only regular files are candidates for deletion.
            if not img_path.is_file():
                continue
            try:
                # cv2.imread signals an undecodable file by returning None.
                readable = cv2.imread(str(img_path)) is not None
            except Exception:
                # A decoder error also counts as corrupted. Catching
                # Exception (not a bare except) lets Ctrl-C / SystemExit
                # stop the run instead of deleting the current file.
                readable = False
            if not readable:
                img_path.unlink(missing_ok=True)

def fix_format(folder: Path, size=(224,224)):
    """Convert every image to RGB and resize it in place.

    Each regular file directly inside a sub-directory of ``folder`` is
    opened with Pillow, converted to RGB, resized to ``size`` and written
    back to the same path. Files that cannot be processed are deleted.

    Args:
        folder: Root directory whose sub-directories hold the images.
        size: Target ``(width, height)`` passed to ``Image.resize``.
    """
    print("\nStandardizing images...")
    for cls in tqdm(sorted(os.listdir(folder))):
        cls_path = folder / cls
        if not cls_path.is_dir():
            continue
        for file in os.listdir(cls_path):
            img_path = cls_path / file
            # Only regular files are processed (and possibly deleted).
            if not img_path.is_file():
                continue
            try:
                # Build the standardized copy while the source is open...
                with Image.open(img_path) as im:
                    standardized = im.convert("RGB").resize(size)
                # ...and overwrite the file only after the source handle has
                # been released.
                standardized.save(img_path)
            except Exception:
                # Anything Pillow cannot read or write back is dropped from
                # the dataset. Catching Exception (not a bare except) keeps
                # Ctrl-C / SystemExit from deleting the file being processed.
                img_path.unlink(missing_ok=True)

def remove_duplicates(folder: Path):
    """Delete byte-identical copies of files, keeping the first one seen.

    Files are compared by the MD5 digest of their raw bytes. The set of
    seen digests is shared across all class sub-folders, so a file that
    repeats one from an earlier (alphabetically sorted) class folder is
    removed as well.

    Args:
        folder: Root directory whose sub-directories hold the images.
    """
    print("\nRemoving duplicates...")
    seen = set()
    for cls in tqdm(sorted(os.listdir(folder))):
        cls_path = folder / cls
        if not cls_path.is_dir():
            continue
        for file in os.listdir(cls_path):
            img_path = cls_path / file
            # Only regular files are hashed (and possibly deleted).
            if not img_path.is_file():
                continue
            try:
                digest = hashlib.md5(img_path.read_bytes()).hexdigest()
            except OSError:
                # Unreadable files are dropped, as before. Only I/O errors
                # are handled: a bare except here would delete the file on
                # ANY failure (including Ctrl-C or md5 being unavailable),
                # which could wipe the whole dataset.
                img_path.unlink(missing_ok=True)
                continue
            if digest in seen:
                img_path.unlink(missing_ok=True)
            else:
                seen.add(digest)

def denoise(folder: Path):
    """Apply OpenCV non-local-means colour denoising to every image in place.

    Each regular file directly inside a sub-directory of ``folder`` is read
    with OpenCV, denoised and written back to the same path. Files that
    cannot be read are skipped silently; files that fail during denoising
    or writing are reported and left as they are.

    Args:
        folder: Root directory whose sub-directories hold the images.
    """
    print("\nDenoising images...")
    for cls in tqdm(sorted(os.listdir(folder))):
        cls_path = folder / cls
        if not cls_path.is_dir():
            continue
        for file in os.listdir(cls_path):
            img_path = cls_path / file
            if img_path.is_dir():
                continue
            try:
                img = cv2.imread(str(img_path))
                if img is None:
                    continue
                # Positional arguments after dst=None follow the OpenCV
                # signature: h, hColor, templateWindowSize, searchWindowSize.
                img = cv2.fastNlMeansDenoisingColored(img, None, 10, 10, 7, 21)
                written = cv2.imwrite(str(img_path), img)
            except Exception as err:
                # Stay best-effort (one bad file must not abort the run), but
                # report it instead of hiding the failure.
                print(f"Warning: could not denoise {img_path}: {err}")
                continue
            if not written:
                # cv2.imwrite reports failure through its return value.
                print(f"Warning: could not write denoised image {img_path}")

def show_balance(folder: Path):
    """Print how many regular files each class sub-folder contains.

    Sub-directories of ``folder`` are listed in sorted name order; entries
    of ``folder`` that are not directories are ignored, and nested
    directories inside a class folder are not counted.

    Args:
        folder: Root directory whose sub-directories are the classes.
    """
    print("\nClass Balance:")
    for name in sorted(os.listdir(folder)):
        class_dir = folder / name
        if not class_dir.is_dir():
            continue
        # Each regular file contributes 1 to the tally.
        n_files = sum(1 for entry in class_dir.iterdir() if entry.is_file())
        print(f"{name}: {n_files}")

# Run cleaning
# fix_format and denoise rewrite every image in place, and CLEAN_DIR is kept
# between runs, so repeating them would re-encode and re-smooth images that
# were already cleaned. A marker file records that the in-place steps have
# completed; delete it (or CLEAN_DIR) to force a fresh cleaning pass.
# The marker is a plain file at the top level of CLEAN_DIR, which the cleaning
# helpers skip because they only descend into sub-directories.
CLEAN_MARKER = CLEAN_DIR / ".cleaned"
if CLEAN_MARKER.exists():
    print("\nClean data already processed; skipping cleaning steps.")
else:
    # Order matters: drop unreadable files, standardize, then de-duplicate the
    # standardized bytes, and denoise last.
    remove_corrupted(CLEAN_DIR)
    fix_format(CLEAN_DIR, IMG_SIZE)
    remove_duplicates(CLEAN_DIR)
    denoise(CLEAN_DIR)
    CLEAN_MARKER.touch()
show_balance(CLEAN_DIR)

print("\nCleaning pipeline completed.")

# -----------------------------
# Data generators
# No `rescale` here: the Keras EfficientNet models do their own input
# rescaling and expect raw pixel values in the [0, 255] range. Dividing by 255
# beforehand would feed the backbone inputs on the wrong scale.
train_datagen = ImageDataGenerator(
    validation_split=VAL_SPLIT,
    rotation_range=25,
    width_shift_range=0.15,
    height_shift_range=0.15,
    zoom_range=0.25,
    shear_range=0.2,
    # NOTE(review): flipping mirrors the hand; confirm that no gesture class
    # depends on left/right orientation.
    horizontal_flip=True,
    brightness_range=(0.7, 1.3)
)

# Validation images are not augmented; the same validation_split is used so
# both generators partition the directory identically.
val_datagen = ImageDataGenerator(
    validation_split=VAL_SPLIT
)

train_gen = train_datagen.flow_from_directory(
    str(CLEAN_DIR),
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    subset="training",
    seed=SEED
)

# shuffle=False keeps the prediction order aligned with val_gen.classes,
# which the evaluation code relies on.
val_gen = val_datagen.flow_from_directory(
    str(CLEAN_DIR),
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="categorical",
    subset="validation",
    seed=SEED,
    shuffle=False
)

num_classes = train_gen.num_classes
class_indices = train_gen.class_indices
# Reverse lookup: integer label -> class (folder) name.
idx_to_class = {v:k for k,v in class_indices.items()}

# -----------------------------
# Build EfficientNetB0 model
# Derive the model input shape from IMG_SIZE so the generators' target_size
# and the network input cannot drift apart if IMG_SIZE is changed.
INPUT_SHAPE = tuple(IMG_SIZE) + (3,)

# ImageNet-pretrained backbone without its classification head, frozen for
# the first training stage.
base_model = EfficientNetB0(include_top=False, weights="imagenet", input_shape=INPUT_SHAPE)
base_model.trainable = False

inputs = layers.Input(shape=INPUT_SHAPE)
# training=False runs the backbone in inference mode when it is called here.
x = base_model(inputs, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.4)(x)
outputs = layers.Dense(num_classes, activation="softmax")(x)
model = models.Model(inputs, outputs)

model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
              loss="categorical_crossentropy",
              metrics=["accuracy"])

# Keep the best model by validation accuracy on disk, stop when validation
# accuracy stalls, and halve the learning rate when validation loss plateaus.
callbacks = [
    ModelCheckpoint("best_efficientnetb0.h5", monitor="val_accuracy", save_best_only=True, mode="max"),
    EarlyStopping(monitor="val_accuracy", patience=6, restore_best_weights=True),
    ReduceLROnPlateau(monitor="val_loss", factor=0.5, patience=3, min_lr=1e-6)
]

# -----------------------------
# Train model
# Stage 1: fit with the model as compiled above, for up to 30 epochs.
history = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=30,
    callbacks=callbacks,
)

# Fine-tuning
# Stage 2: unfreeze the backbone, then re-freeze everything except its last
# 40 layers so only the top of the network is updated.
base_model.trainable = True
for layer in base_model.layers[:-40]:
    layer.trainable = False

# Recompile so the new trainable flags take effect, with a much smaller
# learning rate than stage 1.
model.compile(
    optimizer=tf.keras.optimizers.Adam(5e-5),
    loss="categorical_crossentropy",
    metrics=["accuracy"],
)

# The same callback objects are reused for the second fit.
history_ft = model.fit(
    train_gen,
    validation_data=val_gen,
    epochs=15,
    callbacks=callbacks,
)

# -----------------------------
# Evaluate model
# val_gen was created with shuffle=False, so after reset() the predictions
# come out in the same order as val_gen.classes.
val_gen.reset()
y_true = val_gen.classes
y_pred_probs = model.predict(val_gen)
y_pred = np.argmax(y_pred_probs, axis=1)

# Pass the full label set explicitly. Without `labels`, scikit-learn infers
# the classes from y_true/y_pred and raises if a class is missing from both
# (its count then no longer matches target_names). It also keeps the
# confusion matrix at num_classes x num_classes.
label_ids = list(range(num_classes))
class_names = [idx_to_class[i] for i in label_ids]

print("\nClassification Report:")
print(classification_report(y_true, y_pred, labels=label_ids, target_names=class_names))

cm = confusion_matrix(y_true, y_pred, labels=label_ids)
plt.figure(figsize=(10,8))
# Label both axes with class names so rows/columns can be read directly.
sns.heatmap(cm, cmap="Blues", annot=False, xticklabels=class_names, yticklabels=class_names)
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.show()

# -----------------------------
# Final accuracy
loss, acc = model.evaluate(val_gen, verbose=0)
print(f"\nFinal Validation Accuracy: {acc*100:.2f}%")