In [None]:
import cv2 as cv
import mediapipe as mp
import numpy as np
import pandas as pd
import time
import os
import json
import argparse
import csv
import copy
import random 
import math
from scipy.spatial.distance import cosine
import itertools
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score
from mediapipe.tasks.python.components import containers
import pickle

In [280]:
# Hand landmark detector for the webcam stream.
# VIDEO running mode means frames are fed with detect_for_video(image, timestamp_ms).
base_options_hand = python.BaseOptions(model_asset_path='hand_landmarker.task')
options_hand = vision.HandLandmarkerOptions(
    base_options=base_options_hand,
    running_mode=vision.RunningMode.VIDEO,
    num_hands=2,
    min_hand_detection_confidence=0.5,
    min_hand_presence_confidence=0.5,
    min_tracking_confidence=0.5,
)
detector_hand = vision.HandLandmarker.create_from_options(options_hand)

# Drawing helpers used to render the hand skeleton.
mp_hands = mp.tasks.vision.HandLandmarksConnections
mp_drawing = mp.tasks.vision.drawing_utils

In [281]:
# --- 1. Hand Landmarker (Existing) ---
base_options_hand = python.BaseOptions(model_asset_path='hand_landmarker.task')
options_hand = vision.HandLandmarkerOptions(
    base_options=base_options_hand,
    num_hands=2,
    min_hand_detection_confidence=0.5,
    min_hand_presence_confidence=0.5,
    min_tracking_confidence=0.5,
    running_mode=vision.RunningMode.VIDEO
)
detector_hand = vision.HandLandmarker.create_from_options(options_hand)

# # --- 2. Face Detector (BlazeFace) ---
# # Finds where the face is so we can crop it
# base_options_det = python.BaseOptions(model_asset_path='blaze_face_short_range.tflite')
# options_det = vision.FaceDetectorOptions(
#     base_options=base_options_det,
#     running_mode=vision.RunningMode.IMAGE # We run this manually on frames
# )
# face_detector = vision.FaceDetector.create_from_options(options_det)

# # --- 3. Image Embedder (MobileNet) ---
# # Converts the cropped face image into a vector for recognition
# base_options_emb = python.BaseOptions(model_asset_path='face_embedder.tflite')
# options_emb = vision.ImageEmbedderOptions(
#     base_options=base_options_emb,
#     l2_normalize=True, # Critical for cosine similarity
#     quantize=True,
#     running_mode=vision.RunningMode.IMAGE
# )
# face_embedder = vision.ImageEmbedder.create_from_options(options_emb)

mp_hands = mp.tasks.vision.HandLandmarksConnections
mp_drawing = mp.tasks.vision.drawing_utils

In [282]:
class RobustFaceManager:
    """Identify users by face, backed by a JSON database of face embeddings.

    A MediaPipe FaceDetector locates faces. The first detection is cropped and turned
    into a vector by a MediaPipe ImageEmbedder. That vector is compared by cosine
    similarity against every stored sample of every user.

    The database is a dict ``{user_name: [embedding_as_list, ...]}`` persisted as JSON
    in ``database_file``.
    """

    def __init__(self, database_file='user_db.json', similarity_threshold=0.65,
                 detector_model_path='blaze_face_short_range.tflite',
                 embedder_model_path='face_embedder.tflite'):
        """Create the detector and embedder, then load ``database_file`` if it exists.

        Args:
            database_file: path of the JSON embedding database.
            similarity_threshold: cosine similarity a match must exceed.
            detector_model_path: MediaPipe face detector model file.
            embedder_model_path: MediaPipe image embedder model file.
        """
        self.user_database = {}
        self.similarity_threshold = similarity_threshold
        self.database_file = database_file

        # 1. Face detector, run on single images (IMAGE mode).
        options_det = vision.FaceDetectorOptions(
            base_options=python.BaseOptions(model_asset_path=detector_model_path),
            running_mode=vision.RunningMode.IMAGE
        )
        self.detector = vision.FaceDetector.create_from_options(options_det)

        # 2. Image embedder; l2_normalize keeps embeddings unit-length for cosine similarity.
        # NOTE(review): quantize=True stores a byte-valued embedding; confirm the precision
        # loss is acceptable for recognition.
        options_emb = vision.ImageEmbedderOptions(
            base_options=python.BaseOptions(model_asset_path=embedder_model_path),
            l2_normalize=True,
            quantize=True,
            running_mode=vision.RunningMode.IMAGE
        )
        self.embedder = vision.ImageEmbedder.create_from_options(options_emb)

        self.load_database()

    def load_database(self):
        """Load the JSON database and drop near-zero vectors, which break cosine similarity.

        A missing file leaves the database unchanged. An unreadable file is reported and
        results in an empty database.
        """
        if not os.path.exists(self.database_file):
            return
        try:
            with open(self.database_file, 'r') as f:
                raw_db = json.load(f)

            cleaned = {}
            for name, vectors in raw_db.items():
                valid_vectors = [v for v in vectors
                                 if np.linalg.norm(np.array(v, dtype=float)) > 0.001]
                # Users left with no usable samples are dropped entirely.
                if valid_vectors:
                    cleaned[name] = valid_vectors
            self.user_database = cleaned

            print(f"Loaded {len(self.user_database)} users (cleaned bad data).")
        except Exception as e:
            print(f"Database error: {e}")
            self.user_database = {}

    def save_database(self):
        """Write the database to ``database_file`` as JSON.

        The data goes to a temporary file that is then moved over the real one. A crash
        mid-write therefore cannot leave a truncated, unloadable database.
        """
        serializable_db = {
            # JSON needs plain lists, not numpy arrays.
            name: [v if isinstance(v, list) else np.asarray(v).tolist() for v in vectors]
            for name, vectors in self.user_database.items()
        }

        tmp_path = f"{self.database_file}.tmp"
        with open(tmp_path, 'w') as f:
            json.dump(serializable_db, f)
        os.replace(tmp_path, self.database_file)
        print("Database saved.")

    def get_face_box(self, mp_image):
        """Return the bounding box of the first detected face, or None if there is none."""
        detection_result = self.detector.detect(mp_image)
        if detection_result.detections:
            return detection_result.detections[0].bounding_box
        return None

    @staticmethod
    def _crop_to_box(image_np, bbox):
        """Return the region of ``image_np`` covered by ``bbox``, clipped to the image.

        Both the near and the far edges are clipped. A box hanging off the left or top
        edge is therefore shortened instead of shifted into non-face pixels. The result
        is empty when the box lies entirely outside the image.
        """
        img_h, img_w = image_np.shape[:2]
        x1 = max(0, bbox.origin_x)
        y1 = max(0, bbox.origin_y)
        x2 = min(img_w, bbox.origin_x + bbox.width)
        y2 = min(img_h, bbox.origin_y + bbox.height)
        if x2 <= x1 or y2 <= y1:
            return image_np[0:0, 0:0]
        return image_np[y1:y2, x1:x2]

    def get_face_embedding(self, mp_image):
        """Embed the first detected face in ``mp_image``; return None if no usable face."""
        # A. DETECT
        bbox = self.get_face_box(mp_image)
        if bbox is None:
            return None

        # B. CROP
        face_crop_np = self._crop_to_box(mp_image.numpy_view(), bbox)
        if face_crop_np.size == 0:
            return None

        # C. EMBED (MediaPipe needs a contiguous buffer; a slice may not be one)
        mp_face_crop = mp.Image(image_format=mp.ImageFormat.SRGB,
                                data=np.ascontiguousarray(face_crop_np))
        embedding_result = self.embedder.embed(mp_face_crop)

        if embedding_result.embeddings:
            return embedding_result.embeddings[0].embedding
        return None

    def _safe_cosine_similarity(self, v1, v2):
        """Cosine similarity that returns 0.0 instead of failing on unusable vectors.

        Zero vectors, and vectors of different lengths (for example samples recorded with
        a different embedder model), are treated as "no similarity".
        """
        a = np.asarray(v1, dtype=float)
        b = np.asarray(v2, dtype=float)
        if a.shape != b.shape:
            return 0.0

        norm_a = np.linalg.norm(a)
        norm_b = np.linalg.norm(b)
        if norm_a == 0 or norm_b == 0:
            return 0.0

        return float(np.dot(a, b) / (norm_a * norm_b))

    def register_user_sample(self, name, mp_image):
        """Add one face embedding from ``mp_image`` to ``name``'s samples (in memory only).

        Returns:
            True if a face was found and a non-degenerate embedding was stored.
        """
        embedding = self.get_face_embedding(mp_image)
        if embedding is None:
            return False
        # Same norm cut-off that load_database uses to discard bad vectors.
        if np.linalg.norm(embedding) < 0.001:
            return False

        if isinstance(embedding, np.ndarray):
            embedding = embedding.tolist()
        self.user_database.setdefault(name, []).append(embedding)
        return True

    def identify(self, mp_image):
        """Return ``(user_name, similarity)`` for the face in ``mp_image``.

        ``user_name`` is "Unknown" in three cases: no face is found, the database is
        empty, or the best similarity does not exceed ``similarity_threshold``. In the
        first two cases the similarity is 0.0.
        """
        unknown_vector = self.get_face_embedding(mp_image)
        if unknown_vector is None:
            return "Unknown", 0.0

        best_user, best_similarity = "Unknown", None
        for name, saved_list in self.user_database.items():
            for saved_vector in saved_list:
                sim = self._safe_cosine_similarity(unknown_vector, saved_vector)
                if best_similarity is None or sim > best_similarity:
                    best_user, best_similarity = name, sim

        if best_similarity is None:
            return "Unknown", 0.0
        if best_similarity > self.similarity_threshold:
            return best_user, best_similarity
        return "Unknown", best_similarity

In [283]:
# Build the face manager; it loads (and cleans) 'user_db.json' if that file exists.
face_manager = RobustFaceManager(database_file='user_db.json')

In [284]:
def enroll_new_user(user_name, capture_duration=10, capture=None, manager=None, persist=True):
    """Enroll ``user_name`` by collecting face embeddings from the webcam for a while.

    Every frame read during ``capture_duration`` seconds is passed to
    ``manager.register_user_sample``. That method detects and crops the face and stores
    its embedding when one is found.

    Args:
        user_name: name the samples are stored under.
        capture_duration: how long to capture, in seconds.
        capture: opened ``cv.VideoCapture`` to read from; defaults to the global ``cap``.
        manager: ``RobustFaceManager`` to register into; defaults to the global
            ``face_manager``.
        persist: when True and at least one sample was captured, save the database to
            disk afterwards.

    Returns:
        The number of samples that were registered.
    """
    capture = cap if capture is None else capture
    manager = face_manager if manager is None else manager
    window_name = f"Enrolling {user_name}"
    print(f"Starting enrollment for {user_name}. Please move your head slowly...")

    samples_added = 0
    window_opened = False
    start_time = time.time()
    try:
        while (time.time() - start_time) < capture_duration:
            success, frame = capture.read()
            if not success:
                break

            # OpenCV frames are BGR; MediaPipe's SRGB image format expects RGB order.
            frame_rgb = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
            mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=frame_rgb)
            if manager.register_user_sample(user_name, mp_image):
                samples_added += 1

            cv.imshow(window_name, frame)
            window_opened = True
            cv.waitKey(5)
    finally:
        # Close the enrollment preview even if capture fails part-way.
        if window_opened:
            cv.destroyWindow(window_name)

    if persist and samples_added:
        manager.save_database()
    print(f"Enrollment complete! Captured {samples_added} angles.")
    return samples_added

In [285]:
def draw_landmarks_on_image(rgb_image, detection_result, prediction, confidence, brect):
    """Return a copy of ``rgb_image`` with every hand skeleton and a prediction label drawn.

    Args:
        rgb_image: image array to annotate; it is not modified.
        detection_result: HandLandmarker result; only ``hand_landmarks`` is read.
        prediction: label text to display.
        confidence: score shown next to the label with three decimals.
        brect: ``[x1, y1, x2, y2]`` box the label is anchored above. If empty or None,
            no label is drawn.

    Returns:
        The annotated copy. The label is only drawn when at least one hand was detected.
    """
    annotated_image = np.copy(rgb_image)
    hand_landmarks_list = detection_result.hand_landmarks

    for hand_landmarks in hand_landmarks_list:
        mp_drawing.draw_landmarks(
            annotated_image,
            hand_landmarks,
            mp_hands.HAND_CONNECTIONS)

    # The label depends only on the arguments, so draw it once instead of once per hand.
    if hand_landmarks_list and brect is not None and len(brect) > 0:
        cv.putText(annotated_image, f"{prediction}({confidence:.3f})", (brect[0], brect[1] - 5),
                   cv.FONT_HERSHEY_PLAIN, 1, (0, 0, 0), 1)

    return annotated_image

In [286]:
def get_landmarks(image):
    """Return the per-hand landmark lists stored on a HandLandmarker result.

    Despite the parameter name, ``image`` is the detection result object. Its
    ``hand_landmarks`` attribute is a list of lists, one inner list per detected hand.
    """
    landmarks_per_hand = image.hand_landmarks
    return landmarks_per_hand

In [287]:
def normalize(hand_landmarks):
    """Flatten landmarks into translation- and scale-normalized ``[x0, y0, x1, y1, ...]``.

    Each landmark may be an object with ``.x``/``.y`` attributes or an indexable
    ``(x, y)`` pair. Coordinates are first shifted so that the first landmark is the
    origin. They are then divided by the largest absolute coordinate, so every value
    lies in [-1, 1].

    Returns:
        A flat list of ``2 * len(hand_landmarks)`` floats; ``[]`` when there are no landmarks.
    """
    # Landmarks are only read, so no defensive copy is needed.
    points = [
        (lm.x if hasattr(lm, 'x') else lm[0],
         lm.y if hasattr(lm, 'y') else lm[1])
        for lm in hand_landmarks
    ]
    if not points:
        return []

    base_x, base_y = points[0]
    relative = [(px - base_x, py - base_y) for px, py in points]

    # A degenerate hand (all points identical) has max 0; divide by 1 instead of 0.
    max_value = max(abs(v) for pair in relative for v in pair) or 1

    return [v / max_value for pair in relative for v in pair]

In [288]:
def get_canonical_landmarks(landmarks, handedness):
    """Return ``[[x, y], ...]`` for ``landmarks``, mirrored horizontally on request.

    ``handedness`` acts as a flag. When it is truthy, each x becomes ``1.0 - x``, so
    flagged hands map onto the same orientation as unflagged ones.
    """
    def _to_point(lm):
        mirrored_x = 1.0 - lm.x if handedness else lm.x
        return [mirrored_x, lm.y]

    return [_to_point(lm) for lm in landmarks]

In [289]:
def augment_landmarks_list(landmark_list, num_augmentations=4, max_rotation_deg=10,
                           scale_range=(0.9, 1.1), rng=None):
    """Return ``landmark_list`` plus randomly rotated and scaled copies of it.

    ``landmark_list`` is a flat ``[x0, y0, x1, y1, ...]`` list. For each copy, one
    rotation angle drawn from ``[-max_rotation_deg, max_rotation_deg]`` and one scale
    factor drawn from ``scale_range`` are applied to every (x, y) pair, about the origin.

    Args:
        landmark_list: flat coordinate list (even length).
        num_augmentations: number of augmented copies to generate.
        max_rotation_deg: maximum absolute rotation, in degrees.
        scale_range: ``(low, high)`` bounds for the uniform scale factor.
        rng: object with a ``uniform(a, b)`` method, e.g. ``random.Random(seed)``, for
            reproducible output. Defaults to the global ``random`` module.

    Returns:
        ``[landmark_list, copy_1, ..., copy_n]``. The first row is the input object itself.
    """
    rng = random if rng is None else rng
    augmented_rows = [landmark_list]

    for _ in range(num_augmentations):
        # Draw the angle first and then the scale; this keeps the original random-call order.
        theta = math.radians(rng.uniform(-max_rotation_deg, max_rotation_deg))
        c, s = math.cos(theta), math.sin(theta)
        scale = rng.uniform(*scale_range)

        new_landmarks = []
        for i in range(0, len(landmark_list), 2):
            x, y = landmark_list[i], landmark_list[i + 1]
            new_landmarks.extend([((x * c) - (y * s)) * scale,
                                  ((x * s) + (y * c)) * scale])
        augmented_rows.append(new_landmarks)

    return augmented_rows

In [290]:
def train_custom_model(csv_path='gesture_data.csv', model_save_path='gesture_model.pkl'):
    """Train a linear-kernel SVC gesture classifier from a CSV file and pickle it.

    The CSV has no header. Column 0 holds the gesture label and the remaining columns
    hold the feature vector; the recording loop writes the label followed by the
    normalized landmark coordinates.

    Args:
        csv_path: training data file.
        model_save_path: where the fitted model is pickled.

    Returns:
        The fitted ``SVC``. Returns None when the file is missing or empty, or contains
        fewer than two gesture classes.
    """
    print(f"Loading data from {csv_path}...")
    try:
        df = pd.read_csv(csv_path, header=None)
    except FileNotFoundError:
        print("Error: CSV file not found. Have you recorded any gestures yet?")
        return None
    except pd.errors.EmptyDataError:
        # The file exists but has no rows (e.g. recording was stopped immediately).
        print("Error: CSV file is empty. Have you recorded any gestures yet?")
        return None

    # Column 0 = label, columns 1.. = features.
    X = df.iloc[:, 1:].values
    y = df.iloc[:, 0].values

    # A classifier needs at least two classes to fit.
    unique_classes = np.unique(y)
    if len(unique_classes) < 2:
        print(f"CANNOT TRAIN YET: Found only 1 class ({unique_classes[0]}).")
        print(">>> PLEASE RECORD A 'NEUTRAL' GESTURE NEXT! <<<")
        return None

    # Stratified splitting needs at least 2 samples per class; otherwise train on everything.
    # NOTE(review): each recorded frame is written together with its augmented copies
    # (augment_landmarks_list), so near-duplicates end up on both sides of this random
    # split. The reported accuracy is therefore likely optimistic.
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, stratify=y, random_state=42
        )
    except ValueError:
        print("Warning: Not enough data to split perfectly. Training on full dataset.")
        X_train, y_train = X, y
        X_test, y_test = [], []

    print(f"Training on {len(X_train)} samples with classes: {unique_classes}")

    # random_state makes the internal probability calibration reproducible.
    model = SVC(kernel='linear', probability=True, random_state=42)
    model.fit(X_train, y_train)

    if len(X_test) > 0:
        y_pred = model.predict(X_test)
        accuracy = accuracy_score(y_test, y_pred)
        print(f"Model Accuracy: {accuracy * 100:.2f}%")

    with open(model_save_path, 'wb') as f:
        pickle.dump(model, f)

    print(f"Model saved to {model_save_path}!")
    return model

In [291]:
# Train on the shared default files. The interactive loop further down writes per-user
# files ("<user>_data.csv" / "<user>_model.pkl") instead of 'gesture_data.csv'.
train_custom_model(csv_path='gesture_data.csv', model_save_path='gesture_model.pkl')

Loading data from gesture_data.csv...
Error: CSV file not found. Have you recorded any gestures yet?


In [292]:
def calc_bounding_rect(image, landmarks):
    """Return the pixel box ``[x1, y1, x2, y2]`` enclosing normalized ``landmarks``.

    Landmark ``.x``/``.y`` values are scaled by the image width and height, truncated to
    ints, and clamped into the image on both sides, because landmarks can fall slightly
    outside [0, 1]. ``x2``/``y2`` are exclusive: each is the largest coordinate plus one,
    which matches ``cv.boundingRect``'s ``x + w`` / ``y + h`` convention.

    Raises:
        ValueError: if ``landmarks`` is empty.
    """
    image_height, image_width = image.shape[:2]

    coords = np.array(
        [[int(lm.x * image_width), int(lm.y * image_height)] for lm in landmarks],
        dtype=int,
    ).reshape(-1, 2)
    if coords.shape[0] == 0:
        raise ValueError("calc_bounding_rect needs at least one landmark")

    coords[:, 0] = np.clip(coords[:, 0], 0, image_width - 1)
    coords[:, 1] = np.clip(coords[:, 1], 0, image_height - 1)

    x_min, y_min = coords.min(axis=0)
    x_max, y_max = coords.max(axis=0)
    return [int(x_min), int(y_min), int(x_max) + 1, int(y_max) + 1]

In [293]:
def draw_bounding_rect(use_brect, image, brect):
    """Outline ``brect`` (``[x1, y1, x2, y2]``) in black on ``image`` when ``use_brect`` is set.

    The image is drawn on in place and also returned for convenience.
    """
    if not use_brect:
        return image

    top_left = (brect[0], brect[1])
    bottom_right = (brect[2], brect[3])
    cv.rectangle(image, top_left, bottom_right, (0, 0, 0), 2)
    return image

In [294]:
# --- INITIALIZATION ---
cap = cv.VideoCapture(0)
face_manager = RobustFaceManager() # Ensures DB is loaded and clean

# State Variables
current_user = None
current_hand_model = None
recording_mode = False      
enrollment_mode = False     
frames_recorded = 0
frame_count = 0 
missed_face_frames = 0 # <--- NEW: Memory Counter

print("Controls:")
print("  'e' -> Enroll a NEW USER")
print("  'r' -> Record a NEW GESTURE")
print("  's' -> STOP and SAVE")
print("  'q' -> Quit")

while True:
    success, img = cap.read()
    if not success: break
    frame_count += 1
    
    key = cv.waitKey(5) & 0xFF
    if key == ord('q') or key == 27: break

    # Prepare Image
    debug_image = copy.deepcopy(img)
    image_rgb = cv.cvtColor(img, cv.COLOR_BGR2RGB)
    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_rgb)
    frame_timestamp_ms = int(time.time() * 1000)

    # ----------------------------------------------------
    # PHASE 1: VISUALS & PERSISTENCE CHECK
    # ----------------------------------------------------
    face_bbox = face_manager.get_face_box(mp_image)
    
    if face_bbox:
        # Face Found -> Reset counter
        missed_face_frames = 0 
        x, y, w, h = face_bbox.origin_x, face_bbox.origin_y, face_bbox.width, face_bbox.height
        cv.rectangle(debug_image, (x, y), (x + w, y + h), (255, 255, 0), 2)
    else:
        # No Face -> Increment counter
        missed_face_frames += 1
        
        # LOGIC: Only reset if we haven't seen a face for ~1 second (20 frames)
        # AND we are NOT currently enrolling a user.
        if missed_face_frames > 20 and not enrollment_mode:
            if current_user is not None:
                print("User left. Resetting...")
                current_user = None
                current_hand_model = None
    
    # ----------------------------------------------------
    # PHASE 2: RECOGNITION & ENROLLMENT
    # ----------------------------------------------------
    
    # A. ENROLLMENT (Always run if active, even if face momentarily lost)
    if enrollment_mode and face_bbox:
        success = face_manager.register_user_sample(current_user, mp_image)
        if success:
            cv.putText(debug_image, f"ENROLLING {current_user}...", (10, 50), 
                       cv.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 255), 2)
            count = len(face_manager.user_database.get(current_user, []))
            cv.putText(debug_image, f"Samples: {count}", (10, 90), 
                       cv.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)

    # B. RECOGNITION (Run periodically if face is present)
    elif face_bbox and (frame_count % 10 == 0) and not recording_mode:
        user_name, score = face_manager.identify(mp_image)
        
        if user_name != "Unknown":
            # Only load if it's a DIFFERENT user
            if user_name != current_user:
                print(f"Switched User: {user_name} (Score: {score:.2f})")
                current_user = user_name
                try:
                    with open(f'{current_user}_model.pkl', 'rb') as f:
                        current_hand_model = pickle.load(f)
                    print(f"Loaded gestures for {current_user}")
                except FileNotFoundError:
                    print(f"No gestures found for {current_user}")
                    current_hand_model = None
        else:
            
            pass

    # Visual Status
    if current_user:
        cv.putText(debug_image, f"User: {current_user}", (10, 30), 
                   cv.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
    else:
        cv.putText(debug_image, "User: Unknown", (10, 30), 
                   cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

    # ----------------------------------------------------
    # PHASE 3: HAND GESTURES
    # ----------------------------------------------------
    hand_result = detector_hand.detect_for_video(mp_image, frame_timestamp_ms)
    
    if hand_result.hand_landmarks:
        for i, hand_landmarks in enumerate(hand_result.hand_landmarks):
            # Hand BBox
            h_img, w_img, _ = debug_image.shape
            x_vals = [lm.x for lm in hand_landmarks]
            y_vals = [lm.y for lm in hand_landmarks]
            brect = [int(min(x_vals)*w_img), int(min(y_vals)*h_img), int(max(x_vals)*w_img), int(max(y_vals)*h_img)]
            
            # Logic
            handedness_obj = hand_result.handedness[i][0]
            is_left = (handedness_obj.category_name == "Left") 
            canonical_points = get_canonical_landmarks(hand_landmarks, is_left) 
            normalized_flat = normalize(canonical_points)

            # Record
            if recording_mode and current_user:
                batch = augment_landmarks_list(normalized_flat)
                with open(f"{current_user}_data.csv", 'a', newline='') as f:
                    writer = csv.writer(f)
                    for row in batch:
                        writer.writerow([current_label] + row)
                frames_recorded += 1
                cv.putText(debug_image, f"REC: {current_label} ({frames_recorded})", (10, 70), 
                           cv.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

            # Predict
            elif current_hand_model:
                prediction = current_hand_model.predict([normalized_flat])
                confidence = current_hand_model.predict_proba([normalized_flat])
                pred_label = prediction[0]
                max_conf = np.max(confidence)

                if max_conf > 0.6:
                    cv.rectangle(debug_image, (brect[0], brect[1]), (brect[2], brect[3]), (0, 255, 0), 2)
                    debug_image = draw_landmarks_on_image(debug_image, hand_result, pred_label, max_conf, brect)

    # ----------------------------------------------------
    # INPUT CONTROLS
    # ----------------------------------------------------
    if key == ord('e'): # Enroll
        name = input("Enter Name: ")
        print("Flushing buffer...")
        for _ in range(5): cap.read()
        current_user = name
        enrollment_mode = True
        recording_mode = False
        missed_face_frames = 0 # Reset counter so we don't logout immediately

    elif key == ord('r'): # Record Gesture
        if current_user:
            current_label = input("Enter Gesture Name: ")
            print("Flushing buffer...")
            for _ in range(5): cap.read()
            recording_mode = True
            enrollment_mode = False
            frames_recorded = 0
        else:
            print("Enroll a user first!")

    elif key == ord('s'): # Save
        if enrollment_mode:
            face_manager.save_database()
            enrollment_mode = False
            print("User Saved.")
        elif recording_mode:
            recording_mode = False
            # Train
            new_model = train_custom_model(csv_path=f"{current_user}_data.csv", 
                                           model_save_path=f"{current_user}_model.pkl")
            if new_model: current_hand_model = new_model

    cv.imshow("System", debug_image)

cap.release()
cv.destroyAllWindows()

Controls:
  'e' -> Enroll a NEW USER
  'r' -> Record a NEW GESTURE
  's' -> STOP and SAVE
  'q' -> Quit
Flushing buffer...
Database saved.
User Saved.
Flushing buffer...
Loading data from Garv_data.csv...
CANNOT TRAIN YET: Found only 1 class (open palm).
>>> PLEASE RECORD A 'NEUTRAL' GESTURE NEXT! <<<
Flushing buffer...
Loading data from Garv_data.csv...
Training on 2064 samples with classes: ['open palm' 'up']
Model Accuracy: 98.84%
Model saved to Garv_model.pkl!
Flushing buffer...
Loading data from Garv_data.csv...
Training on 3200 samples with classes: ['nothing' 'open palm' 'up']
Model Accuracy: 97.38%
Model saved to Garv_model.pkl!
Flushing buffer...
Loading data from Garv_data.csv...
Training on 4152 samples with classes: ['nothing' 'open palm' 'peace' 'up']
Model Accuracy: 97.50%
Model saved to Garv_model.pkl!
Flushing buffer...
Loading data from Garv_data.csv...
Training on 4784 samples with classes: ['left' 'nothing' 'open palm' 'peace' 'up']
Model Accuracy: 97.49%
Model saved