In [None]:
!pip install mtcnn opencv-python

In [None]:
import json
import os
import random

import cv2
import numpy as np
import tensorflow as tf
from mtcnn import MTCNN


In [None]:
# --- Tunable settings -------------------------------------------------------
# (width, height) every face crop is resized to before it is fed to the model.
IMG_SIZE = (224, 224)
# Number of face crops per batch produced by the data generator.
BATCH_SIZE = 32
# Images whose Laplacian variance is below this value are treated as blurry and
# dropped; raising it rejects more images (stricter filtering).
BLUR_THRESHOLD = 80

# Single shared face detector, created once and reused for every image.
detector = MTCNN()


In [None]:
def is_blurry(image, thresh=BLUR_THRESHOLD):
    """Report whether an image looks out of focus.

    The image is converted from BGR to grayscale and the variance of its
    Laplacian is used as a sharpness score.

    Args:
        image: BGR image array (as returned by ``cv2.imread``).
        thresh: Sharpness score below which the image counts as blurry.

    Returns:
        True when the Laplacian variance is strictly less than ``thresh``.
    """
    grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    sharpness = cv2.Laplacian(grayscale, cv2.CV_64F).var()
    return sharpness < thresh


In [None]:
def _extract_face(path):
    """Load one image file and return its normalized face crop, or None if unusable.

    An image is rejected (None is returned) when it cannot be read, when
    ``is_blurry`` flags it, when the detector does not find exactly one face,
    or when the detected box does not overlap the image.
    """
    img = cv2.imread(path)
    if img is None:
        return None  # unreadable or corrupted file

    if is_blurry(img):
        return None

    # cv2 loads BGR; the detector and the model are fed RGB.
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    faces = detector.detect_faces(rgb)
    if len(faces) != 1:
        return None  # no face, or several faces (ambiguous label)

    x, y, w, h = faces[0]['box']
    # Clamp the box to the image: a negative start would be interpreted by
    # NumPy as an index counted from the end and produce an empty/wrong crop.
    left, top = max(0, x), max(0, y)
    right, bottom = max(0, x + w), max(0, y + h)
    face_crop = rgb[top:bottom, left:right]
    if face_crop.size == 0:
        return None  # nothing left to resize; cv2.resize would fail on it

    face_crop = cv2.resize(face_crop, IMG_SIZE)
    return face_crop / 255.0  # scale pixels to [0, 1]


def clean_image_generator(directory, batch_size=BATCH_SIZE, shuffle=True, seed=None):
    """Endlessly yield batches of cleaned face crops with one-hot labels.

    Each sub-directory of ``directory`` is one class; classes are indexed in
    sorted name order. Non-directory entries in ``directory`` are ignored so
    they cannot shift the label indices. The file list is collected once, when
    iteration starts. Images that fail the checks in ``_extract_face`` are
    skipped. A partially filled batch at the end of a pass is carried over
    into the next pass rather than dropped.

    Args:
        directory: Root folder containing one sub-folder per class.
        batch_size: Number of samples per yielded batch.
        shuffle: If True (default), reshuffle the sample order before every
            pass so batches mix classes instead of arriving class by class.
        seed: Optional seed for the shuffle order.

    Yields:
        Tuple ``(images, labels)``: an array of ``batch_size`` face crops and
        the matching one-hot label matrix with one column per class.

    Raises:
        ValueError: If a complete pass over the data produces no usable image
            (previously this situation looped forever without yielding).
    """
    class_names = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    class_to_index = {name: idx for idx, name in enumerate(class_names)}
    num_classes = len(class_names)

    # (path, label) for every candidate file; sorted so a seed gives a
    # reproducible order regardless of the OS listing order.
    samples = [
        (os.path.join(directory, class_name, file), class_to_index[class_name])
        for class_name in class_names
        for file in sorted(os.listdir(os.path.join(directory, class_name)))
    ]
    rng = random.Random(seed)

    images = []
    labels = []

    while True:
        if shuffle:
            rng.shuffle(samples)

        usable = 0
        for path, label in samples:
            face_crop = _extract_face(path)
            if face_crop is None:
                continue

            usable += 1
            images.append(face_crop)
            labels.append(label)

            # Yield batch
            if len(images) == batch_size:
                yield np.array(images), tf.keras.utils.to_categorical(labels, num_classes=num_classes)
                images, labels = [], []

        if usable == 0:
            raise ValueError(
                f"No usable face images found under {directory!r}; "
                "check the path and the blur/face filters."
            )


In [None]:
# Each sub-directory of "images" is one class; plain files are ignored.
class_names = []
for entry in os.listdir("images"):
    if os.path.isdir(os.path.join("images", entry)):
        class_names.append(entry)
class_names.sort()  # sorted order defines the class indices

num_classes = len(class_names)
print("Classes:", class_names)

In [None]:
# Endless stream of (face batch, one-hot labels) read from the "images" folder.
train_gen = clean_image_generator(directory="images", batch_size=BATCH_SIZE)

In [None]:
# Model input shape is derived from IMG_SIZE so it cannot drift from the size
# the generator resizes face crops to.
INPUT_SHAPE = IMG_SIZE + (3,)

# ImageNet-pretrained MobileNetV2 without its classification head, used as a
# frozen feature extractor.
base = tf.keras.applications.MobileNetV2(
    weights='imagenet',
    include_top=False,
    input_shape=INPUT_SHAPE
)
base.trainable = False

model = tf.keras.Sequential([
    tf.keras.Input(shape=INPUT_SHAPE),
    # The generator feeds pixels in [0, 1], but the pretrained MobileNetV2
    # weights expect inputs in [-1, 1]. Rescaling inside the model fixes the
    # mismatch while callers keep passing [0, 1] images.
    tf.keras.layers.Rescaling(2.0, offset=-1.0),
    base,
    tf.keras.layers.GlobalAveragePooling2D(),
    tf.keras.layers.Dense(256, activation='relu'),
    tf.keras.layers.Dropout(0.3),
    # One softmax output per class folder.
    tf.keras.layers.Dense(num_classes, activation='softmax')
])

# Labels from the generator are one-hot, hence categorical cross-entropy.
model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

In [None]:
# Keep the returned History so the per-epoch loss/accuracy can be inspected or
# plotted later, and so the cell does not end by echoing the object's repr.
# The generator is endless, so steps_per_epoch defines what one epoch means.
history = model.fit(
    train_gen,
    steps_per_epoch=15,   # change based on data size
    epochs=15
)

Epoch 1/15
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m117s[0m 8s/step - accuracy: 0.0167 - loss: 4.6070
Epoch 2/15
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m118s[0m 8s/step - accuracy: 0.1250 - loss: 2.7244
Epoch 3/15
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m133s[0m 9s/step - accuracy: 0.2146 - loss: 2.4819
Epoch 4/15
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m259s[0m 17s/step - accuracy: 0.2500 - loss: 2.3714
Epoch 5/15
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m237s[0m 16s/step - accuracy: 0.2333 - loss: 2.3049
Epoch 6/15
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m144s[0m 8s/step - accuracy: 0.3354 - loss: 2.1247
Epoch 7/15
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m119s[0m 8s/step - accuracy: 0.4042 - loss: 1.9333
Epoch 8/15
[1m15/15[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m268s[0m 19s/step - accuracy: 0.5104 - loss: 1.7340
Epoch 9/15
[1m15/15[0m [32m━━━━━━━━━━━━━━━

<keras.src.callbacks.history.History at 0x24cda9f9e80>

In [None]:
# Save the trained model in the native Keras format.
model.save("cricknet_model.keras")

# Save the class-name -> index mapping next to the model. It is derived from
# class_names (the sorted class folders found under "images") instead of being
# typed out by hand, so it always matches the data the model was trained on.
class_dict = {name: idx for idx, name in enumerate(class_names)}

with open("class_dictionary.json", "w") as f:
    json.dump(class_dict, f)
