In [None]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFont
from sklearn.metrics import confusion_matrix
import os
import seaborn as sns
from facenet_pytorch import InceptionResnetV1, MTCNN
import torchvision.transforms as transforms
import cv2
import mediapipe as mp
import csv

# Device setup: use CUDA when torch reports it as available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Model initialization
# Embedding network: InceptionResnetV1 loaded with the 'vggface2' pretrained
# weights, switched to eval mode and moved to the selected device.
model = InceptionResnetV1(pretrained='vggface2').eval().to(device)
# Face detector/cropper used by extract_embedding(). With keep_all=False it
# presumably yields a single 160x160 crop (40 px margin) per image, or None
# when no face is found — confirm against the facenet-pytorch documentation.
mtcnn = MTCNN(image_size=160, margin=40, device=device, keep_all=False)

# MediaPipe Face Mesh initialization
# One shared FaceMesh instance is reused by every pose-estimation call in this
# notebook. It is closed at the very end of the analysis cell, so the pose
# helpers need this cell to be re-run before they can be used again.
mp_face_mesh = mp.solutions.face_mesh
face_mesh = mp_face_mesh.FaceMesh(
    static_image_mode=True,
    max_num_faces=1,
    refine_landmarks=True,
    min_detection_confidence=0.5
)

# Whole-image preprocessing applied by extract_embedding() when mtcnn returns
# None: resize to 160x160, convert to a tensor, then normalise each channel
# with mean 0.5 / std 0.5 (i.e. x -> (x - 0.5) / 0.5).
fallback_transform = transforms.Compose([
    transforms.Resize((160, 160)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Pose mapping
# Maps a pose id to a pose name. The ids are matched against gallery file-name
# stems (the part before the first '.') when the gallery is built, and the
# names match the categories returned by categorize_pose().
POSE_MAPPING = {
    '00': 'frontal',
    '01': 'up',
    '02': 'down',
    '03': 'left',
    '04': 'right'
}

In [None]:

# ============================================================================
# DISTANCE METRIC SELECTION
# ============================================================================
# True  -> score with Euclidean (L2) distance; accept when distance <= threshold.
# False -> score with cosine similarity; accept when similarity >= threshold.
USE_L2_DISTANCE = True

# L2 thresholds typically sit in 0.6-1.2 (a lower threshold is stricter);
# cosine thresholds typically sit in 0.4-0.7 (a higher threshold is stricter).
SIMILARITY_THRESHOLD = 0.95 if USE_L2_DISTANCE else 0.55

print("\n" + "=" * 80)
print("USING L2 DISTANCE (Euclidean)" if USE_L2_DISTANCE else "USING COSINE SIMILARITY")
print("=" * 80)

if USE_L2_DISTANCE:
    print(f"Threshold: {SIMILARITY_THRESHOLD} (accept if distance <= threshold)")
    print("Lower distance = MORE similar")
    print("Range: 0.0 (identical) to ~2.0 (very different)")
else:
    print(f"Threshold: {SIMILARITY_THRESHOLD} (accept if similarity >= threshold)")
    print("Higher similarity = MORE similar")
    print("Range: -1.0 to 1.0 (typically 0.3 to 0.9 for faces)")

print("=" * 80 + "\n")

"""
L2 DISTANCE vs COSINE SIMILARITY:
==================================

L2 DISTANCE (Euclidean Distance):
  Formula: sqrt(sum((emb1 - emb2)^2))
  Range: 0 to infinity (typically 0.4 to 2.0 for faces)
  Interpretation: LOWER = MORE similar
    - 0.0-0.6: Very similar (same person)
    - 0.6-0.9: Similar (likely same person)
    - 0.9-1.2: Somewhat similar (threshold zone)
    - >1.2: Different people
  Threshold: distance <= threshold → ACCEPT

COSINE SIMILARITY:
  Formula: dot(emb1, emb2) / (||emb1|| * ||emb2||)
  Range: -1 to +1 (typically 0.3 to 0.9 for faces)
  Interpretation: HIGHER = MORE similar
    - 0.8-1.0: Very similar (same person)
    - 0.6-0.8: Similar (likely same person)
    - 0.4-0.6: Somewhat similar (threshold zone)
    - <0.4: Different people
  Threshold: similarity >= threshold → ACCEPT

For L2-normalized embeddings (like ours):
  L2_distance = sqrt(2 * (1 - cosine_similarity))
  
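Recommended Thresholds (conservative = strictest, i.e. fewest false accepts):
  L2: 0.89-0.95 (conservative), 0.95-1.00 (balanced), 1.00-1.05 (lenient)
  Cosine: 0.55-0.60 (conservative), 0.50-0.55 (balanced), 0.45-0.50 (lenient)
  Note: the L2 ranges are the cosine ranges mapped through the formula above.
  For L2 a LOWER threshold is stricter; for cosine a HIGHER threshold is stricter.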
"""

In [None]:


def estimate_head_pose_mediapipe(image_path):
    """Approximate yaw, pitch and roll of the first detected face.

    The angles are heuristics derived from five 2D Face Mesh landmarks
    (indices 33 and 263 for the eye corners, 1 for the nose tip, 61 and 291
    for the mouth corners) scaled to pixel coordinates; no 3D model is fitted.

    Args:
        image_path: Path of the image file, read with cv2.imread.

    Returns:
        Tuple (yaw, pitch, roll, pose_category). Angles are in degrees and
        pose_category comes from categorize_pose(). When the image cannot be
        read or no face is detected the result is
        (None, None, None, 'unknown').
    """
    no_pose = (None, None, None, 'unknown')

    bgr = cv2.imread(image_path)
    if bgr is None:
        return no_pose

    detection = face_mesh.process(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    if not detection.multi_face_landmarks:
        return no_pose

    landmarks = detection.multi_face_landmarks[0].landmark
    height, width = bgr.shape[:2]

    def pixel_xy(index):
        # Landmarks are normalised; scale them to pixel coordinates.
        point = landmarks[index]
        return np.array([point.x * width, point.y * height])

    eye_a = pixel_xy(33)
    eye_b = pixel_xy(263)
    nose = pixel_xy(1)
    mouth_a = pixel_xy(61)
    mouth_b = pixel_xy(291)

    # YAW: horizontal nose offset from the eye midpoint, in half eye-spans,
    # scaled by 45 and clipped to [-90, 90].
    eye_mid_x = (eye_a[0] + eye_b[0]) / 2
    eye_span = np.linalg.norm(eye_b - eye_a)
    if eye_span > 0:
        yaw_ratio = (nose[0] - eye_mid_x) / (eye_span / 2)
    else:
        yaw_ratio = 0
    yaw = np.clip(yaw_ratio * 45, -90, 90)

    # PITCH: vertical nose offset from its expected position (40% of the way
    # from the eye line to the mouth line), scaled by 60, clipped to [-45, 45].
    eye_mid_y = (eye_a[1] + eye_b[1]) / 2
    mouth_mid_y = (mouth_a[1] + mouth_b[1]) / 2
    eye_to_mouth = abs(mouth_mid_y - eye_mid_y)
    nose_target_y = eye_mid_y + 0.4 * eye_to_mouth
    if eye_to_mouth > 0:
        pitch_ratio = (nose[1] - nose_target_y) / eye_to_mouth
    else:
        pitch_ratio = 0
    pitch = np.clip(pitch_ratio * 60, -45, 45)

    # ROLL: angle of the line joining the two eye corners.
    eye_delta = eye_b - eye_a
    roll = np.degrees(np.arctan2(eye_delta[1], eye_delta[0]))

    return yaw, pitch, roll, categorize_pose(yaw, pitch, roll)

def categorize_pose(yaw, pitch, roll):
    """Map yaw/pitch angles (degrees) to a coarse pose label.

    Yaw is checked first: beyond +/-20 the label is 'right' (positive) or
    'left' (negative). Otherwise pitch beyond +/-15 gives 'down' (positive)
    or 'up' (negative). Anything else is 'frontal'. `roll` is accepted for
    signature symmetry but does not influence the result.

    Returns 'unknown' when yaw or pitch is None.
    """
    if yaw is None or pitch is None:
        return 'unknown'

    yaw_limit, pitch_limit = 20, 15

    # Guard clauses: a strong yaw wins over any pitch.
    if abs(yaw) > yaw_limit:
        if yaw > 0:
            return 'right'
        return 'left'

    if abs(pitch) > pitch_limit:
        if pitch > 0:
            return 'down'
        return 'up'

    return 'frontal'

def draw_pose_annotation(image_path, yaw, pitch, roll, pose_category):
    """Return a PIL RGB image annotated with pose axes, angle text and a legend.

    Args:
        image_path: Path of the image to annotate.
        yaw, pitch, roll: Pose angles in degrees, or None when pose estimation
            failed (then only the pose label and legend are drawn).
        pose_category: Pose label shown in the text box.

    Returns:
        A PIL.Image in RGB mode. If OpenCV cannot read the file, the
        un-annotated PIL image is returned as is.
    """
    img = Image.open(image_path).convert('RGB')
    img_cv = cv2.imread(image_path)

    if img_cv is None:
        return img

    img_h, img_w = img_cv.shape[:2]

    # Landmarks are only needed to anchor the axes at the nose tip, so the
    # face mesh is not run at all when there is no pose to draw.
    face_landmark_sets = None
    if yaw is not None:
        img_rgb = cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB)
        face_landmark_sets = face_mesh.process(img_rgb).multi_face_landmarks

    if face_landmark_sets:
        face_landmarks = face_landmark_sets[0]

        nose_tip = face_landmarks.landmark[1]
        nose_x = int(nose_tip.x * img_w)
        nose_y = int(nose_tip.y * img_h)

        yaw_rad = np.radians(yaw)
        pitch_rad = np.radians(pitch)

        axis_length = min(img_w, img_h) // 5

        # Calculate axis endpoints (all axes start at the nose tip)
        x_end_x = int(nose_x + axis_length * np.cos(pitch_rad))
        x_end_y = int(nose_y - axis_length * np.sin(pitch_rad))

        y_end_x = int(nose_x + axis_length * np.sin(yaw_rad))
        y_end_y = int(nose_y)

        z_end_x = int(nose_x + axis_length * 0.3 * np.sin(yaw_rad))
        z_end_y = int(nose_y - axis_length * 0.3 * np.cos(pitch_rad))

        # Draw axes. Colours are BGR here because img_cv is an OpenCV image:
        # (255, 0, 0) is blue, (0, 0, 255) is red, (0, 255, 0) is green.
        cv2.line(img_cv, (nose_x, nose_y), (z_end_x, z_end_y), (255, 0, 0), 3)
        cv2.circle(img_cv, (z_end_x, z_end_y), 5, (255, 0, 0), -1)

        cv2.line(img_cv, (nose_x, nose_y), (x_end_x, x_end_y), (0, 0, 255), 3)
        cv2.circle(img_cv, (x_end_x, x_end_y), 5, (0, 0, 255), -1)

        cv2.line(img_cv, (nose_x, nose_y), (y_end_x, y_end_y), (0, 255, 0), 3)
        cv2.circle(img_cv, (y_end_x, y_end_y), 5, (0, 255, 0), -1)

        # White dot with a black outline marks the axis origin
        cv2.circle(img_cv, (nose_x, nose_y), 7, (255, 255, 255), -1)
        cv2.circle(img_cv, (nose_x, nose_y), 7, (0, 0, 0), 2)

        img = Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))

    draw = ImageDraw.Draw(img)

    if yaw is not None:
        text = f"Yaw: {yaw:.1f}° (Green)\nPitch: {pitch:.1f}° (Red)\nRoll: {roll:.1f}°\nPose: {pose_category}"
    else:
        text = f"Pose: {pose_category}"

    try:
        font = ImageFont.truetype("arial.ttf", 12)
    except OSError:
        # Arial is not installed or not readable here; use Pillow's built-in
        # font. Only OSError is caught so that unrelated failures (and
        # KeyboardInterrupt) are no longer silently swallowed.
        font = ImageFont.load_default()

    # NOTE(review): img is RGB, so the alpha value 180 in the fills below is
    # presumably ignored and the boxes are opaque black — confirm if
    # translucency was intended.
    bbox = draw.textbbox((10, 10), text, font=font)
    draw.rectangle(bbox, fill=(0, 0, 0, 180))
    draw.text((10, 10), text, fill=(0, 255, 255), font=font)

    legend_y = img.height - 80
    legend_text = "Axes:\nRed = Pitch\nGreen = Yaw\nBlue = Forward"
    legend_bbox = draw.textbbox((10, legend_y), legend_text, font=font)
    draw.rectangle(legend_bbox, fill=(0, 0, 0, 180))
    draw.text((10, legend_y), legend_text, fill=(255, 255, 255), font=font)

    return img


In [None]:

def extract_embedding(img_path):
    """Compute an L2-normalised face embedding for one image file.

    The image is loaded as RGB and passed through `mtcnn`. When that yields
    None, the whole image is preprocessed with `fallback_transform` instead.
    The result is run through `model` without gradient tracking.

    Returns:
        A flat 1-D numpy array holding the unit-length embedding.
    """
    rgb_image = Image.open(img_path).convert('RGB')

    face = mtcnn(rgb_image)
    if face is None:
        # No face crop available: embed the resized full image instead.
        face = fallback_transform(rgb_image)

    batch = face.unsqueeze(0).to(device)  # add the batch dimension
    with torch.no_grad():
        unit_embedding = F.normalize(model(batch), p=2, dim=1)

    return unit_embedding.cpu().numpy().flatten()

def calculate_similarity(emb1, emb2):
    """Score two embeddings with the metric selected by USE_L2_DISTANCE.

    Args:
        emb1, emb2: 1-D embeddings of equal length (numpy arrays or anything
            np.asarray accepts).

    Returns:
        - L2 mode: Euclidean distance (lower means more similar).
        - Cosine mode: dot(emb1, emb2) / (||emb1|| * ||emb2||)
          (higher means more similar). The norms are divided out explicitly,
          so the value is a true cosine similarity even for inputs that are
          not unit length. For unit-length embeddings this equals the plain
          dot product. If either vector has zero norm, 0.0 is returned.
    """
    vec_a = np.asarray(emb1)
    vec_b = np.asarray(emb2)

    if USE_L2_DISTANCE:
        # L2 Distance (Euclidean)
        return np.linalg.norm(vec_a - vec_b)  # Lower is better

    # Cosine Similarity
    norm_product = np.linalg.norm(vec_a) * np.linalg.norm(vec_b)
    if norm_product == 0:
        # Direction is undefined for a zero vector; treat as "no similarity"
        # instead of dividing by zero.
        return 0.0
    return np.dot(vec_a, vec_b) / norm_product  # Higher is better

def attention_score(query_emb, gallery_embs, temperature=1.0):
    """Softmax-weighted average of the query's scores against a gallery.

    Each gallery embedding is scored with calculate_similarity(). The scores
    are turned into attention weights with a softmax so that better matches
    (smaller distance in L2 mode, larger similarity in cosine mode) get more
    weight. The weighted average of the raw scores is then returned.

    Args:
        query_emb: Query embedding.
        gallery_embs: Iterable of gallery embeddings for one person.
        temperature: Softmax temperature; must be > 0. Smaller values
            concentrate the weight on the best match.

    Returns:
        (S_p, attention, sims): the weighted score, the attention weights and
        the raw per-embedding scores. For an empty gallery the score is the
        worst possible value (+inf in L2 mode, -inf in cosine mode) and both
        arrays are empty, so an empty gallery can never look like a match.

    Raises:
        ValueError: If temperature is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    sims = np.array([calculate_similarity(query_emb, g) for g in gallery_embs])

    if sims.size == 0:
        # Without this guard the weighted sum of nothing is 0.0, which in L2
        # mode reads as a perfect (distance 0) match.
        worst_score = np.inf if USE_L2_DISTANCE else -np.inf
        return worst_score, sims.copy(), sims

    # For L2 the distances are negated so that "closer" means "larger logit".
    logits = (-sims if USE_L2_DISTANCE else sims) / temperature

    # Subtracting the maximum leaves the softmax unchanged mathematically but
    # keeps exp() from overflowing, or from underflowing to all zeros (0/0).
    exp_logits = np.exp(logits - np.max(logits))
    attention = exp_logits / np.sum(exp_logits)

    # Same formula for both metrics: only the interpretation of the result
    # (lower vs. higher is better) differs.
    S_p = np.sum(attention * sims)

    return S_p, attention, sims

def pose_specific_score(query_emb, gallery_dict, estimated_pose):
    """Score a query against every gallery person, overall and per pose.

    Args:
        query_emb: Query embedding.
        gallery_dict: Mapping person_id -> {pose_id: embedding}.
        estimated_pose: Pose name estimated for the query (for example
            'frontal' or 'unknown').

    Returns:
        Mapping person_id -> dict with keys:
            'overall'        attention-weighted score over all poses,
            'pose_specific'  score against the gallery image whose pose name
                             equals estimated_pose, or the person's best
                             score when that pose is not in their gallery,
            'pose_sims'      pose name -> score,
            'attention'      attention weights from attention_score(),
            'all_sims'       raw scores from attention_score().
        People with no gallery embeddings are left out of the result: there
        is nothing to compare against, and the best-score fallback would
        otherwise fail on an empty sequence.
    """
    # "Best" means smallest distance for L2, largest similarity for cosine.
    pick_best = min if USE_L2_DISTANCE else max

    all_scores = {}

    for person_id, pose_data in gallery_dict.items():
        if not pose_data:
            continue

        # Pose ids that are not in POSE_MAPPING are kept under their raw id.
        pose_sims = {
            POSE_MAPPING.get(pose_id, pose_id): calculate_similarity(query_emb, emb)
            for pose_id, emb in pose_data.items()
        }

        gallery_embs = list(pose_data.values())
        overall_score, attention, sims = attention_score(query_emb, gallery_embs)

        if estimated_pose in pose_sims:
            pose_specific = pose_sims[estimated_pose]
        else:
            pose_specific = pick_best(pose_sims.values())

        all_scores[person_id] = {
            'overall': overall_score,
            'pose_specific': pose_specific,
            'pose_sims': pose_sims,
            'attention': attention,
            'all_sims': sims
        }

    return all_scores

In [None]:
def meets_threshold(score):
    """Tell whether `score` is accepted under the active metric.

    In L2 mode a score passes when it is at most SIMILARITY_THRESHOLD
    (distances: lower is better). In cosine mode it passes when it is at
    least SIMILARITY_THRESHOLD (similarities: higher is better).
    """
    accepted = (
        score <= SIMILARITY_THRESHOLD
        if USE_L2_DISTANCE
        else score >= SIMILARITY_THRESHOLD
    )
    return accepted

def get_best_match(all_scores, method='pose_specific'):
    """Pick the gallery person with the best score.

    Args:
        all_scores: Mapping person_id -> score dict, as produced by
            pose_specific_score().
        method: Key of the score to compare ('pose_specific' or 'overall').

    Returns:
        (best_person, best_score). "Best" is the minimum in L2 mode and the
        maximum in cosine mode; ties go to the first person in iteration
        order. When all_scores is empty, returns (None, worst_score) where
        worst_score is +inf in L2 mode and -inf in cosine mode, so the result
        never passes a threshold check (previously this raised ValueError
        from min()/max() on an empty sequence).
    """
    if not all_scores:
        return None, (float('inf') if USE_L2_DISTANCE else float('-inf'))

    # Minimum distance for L2, maximum similarity for cosine.
    choose = min if USE_L2_DISTANCE else max
    best_person = choose(all_scores, key=lambda person: all_scores[person][method])

    return best_person, all_scores[best_person][method]


In [None]:



# UPDATE THIS PATH (or set the CV1_BASE_PATH environment variable instead).
# Expected layout: <BASE_PATH>/gallery/<person_id>/<pose_id>.<ext> and
# <BASE_PATH>/queries/.
BASE_PATH = os.environ.get('CV1_BASE_PATH', 'E:/Projects/cv1')

# Build gallery
print("\n" + "="*80)
print("BUILDING GALLERY WITH MEDIAPIPE HEAD POSE ESTIMATION")
print("="*80)

gallery_ids = ['00', '01', '02', '03', '04', '05', '06']
gallery_dict = {}     # person_id -> {pose_id: embedding}
gallery_images = {}   # person_id -> {pose_id: PIL image}
gallery_poses = {}    # person_id -> {pose_id: pose info dict}

for person_id in gallery_ids:
    folder_path = os.path.join(BASE_PATH, "gallery", person_id)

    # isdir (not exists) so that a stray file with the person's name is
    # reported here instead of failing inside os.listdir().
    if not os.path.isdir(folder_path):
        print(f"Warning: Folder not found: {folder_path}")
        continue

    pose_embeddings = {}
    person_images = {}
    person_poses = {}

    print(f"\nProcessing gallery person {person_id}:")
    for img_name in sorted(os.listdir(folder_path)):
        if not img_name.lower().endswith(('.jpg', '.jpeg', '.png')):
            continue

        img_path = os.path.join(folder_path, img_name)
        # The pose id is the file name up to the first '.'
        pose_id = img_name.split('.')[0]
        expected_pose = POSE_MAPPING.get(pose_id, 'unknown')

        yaw, pitch, roll, pose_cat = estimate_head_pose_mediapipe(img_path)
        person_poses[pose_id] = {
            'yaw': yaw, 'pitch': pitch, 'roll': roll,
            'category': pose_cat, 'expected': expected_pose
        }

        if yaw is not None:
            print(f"  {pose_id}.jpg: Expected={expected_pose:>8} | "
                  f"Detected={pose_cat:>8} | Yaw={yaw:>6.1f}° Pitch={pitch:>6.1f}° Roll={roll:>6.1f}°")
        else:
            print(f"  {pose_id}.jpg: Pose detection failed")

        emb = extract_embedding(img_path)
        pose_embeddings[pose_id] = emb

        # Image.open is lazy and keeps the file open; copy the pixel data and
        # close the file so no handles accumulate across the gallery.
        with Image.open(img_path) as opened_image:
            person_images[pose_id] = opened_image.copy()

    if not pose_embeddings:
        # Registering a person without embeddings would break scoring later
        # (there is nothing to compare a query against).
        print(f"Warning: No images found in {folder_path}; skipping person {person_id}")
        continue

    gallery_dict[person_id] = pose_embeddings
    gallery_images[person_id] = person_images
    gallery_poses[person_id] = person_poses

print(f"\nGallery built with {len(gallery_dict)} people")

# Process queries: estimate pose, embed, score against the gallery and record
# one result dict per query image.
queries_path = os.path.join(BASE_PATH, 'queries')

if not os.path.exists(queries_path):
    raise FileNotFoundError(f"Queries folder not found: {queries_path}")

query_files = sorted(
    entry for entry in os.listdir(queries_path)
    if entry.lower().endswith(('.jpg', '.jpeg', '.png'))
)

print("\n" + "="*80)
print(f"PROCESSING ALL QUERIES WITH THRESHOLD = {SIMILARITY_THRESHOLD}")
print("="*80)

results = []
# Results bucketed by the pose estimated for the query ('unknown' included).
pose_specific_results = {pose: [] for pose in [*POSE_MAPPING.values(), 'unknown']}

for query_name in query_files:
    query_path = os.path.join(queries_path, query_name)

    yaw, pitch, roll, estimated_pose = estimate_head_pose_mediapipe(query_path)
    query_emb = extract_embedding(query_path)

    # Ground truth is the second '_'-separated token of the file name (up to
    # the first '.'). Id '07' and anything outside gallery_ids is 'unknown'.
    ground_truth = 'unknown'
    parts = query_name.split('_')
    if len(parts) >= 2:
        potential_id = parts[1].split('.')[0]
        if potential_id != '07' and potential_id in gallery_ids:
            ground_truth = potential_id

    is_known = ground_truth in gallery_ids

    all_scores = pose_specific_score(query_emb, gallery_dict, estimated_pose)

    # Baseline prediction: attention-weighted score over all gallery poses
    predicted_overall_id, score_overall = get_best_match(all_scores, 'overall')
    meets_threshold_overall = bool(meets_threshold(score_overall))
    final_prediction_overall = predicted_overall_id if meets_threshold_overall else 'unknown'

    # Pose-aware prediction: score against the matching gallery pose
    predicted_pose_id, score_pose = get_best_match(all_scores, 'pose_specific')
    meets_threshold_pose = bool(meets_threshold(score_pose))
    final_prediction_pose = predicted_pose_id if meets_threshold_pose else 'unknown'

    result = {
        'query': query_name,
        'ground_truth': ground_truth,
        'is_known': is_known,
        'yaw': yaw,
        'pitch': pitch,
        'roll': roll,
        'estimated_pose': estimated_pose,
        'predicted_overall': final_prediction_overall,
        'score_overall': score_overall,
        'meets_threshold_overall': meets_threshold_overall,
        'predicted_pose_specific': final_prediction_pose,
        'score_pose_specific': score_pose,
        'meets_threshold_pose': meets_threshold_pose,
        'best_match_id': predicted_pose_id,
        'best_match_score': score_pose,
        'correct_overall': (final_prediction_overall == ground_truth),
        'correct_pose_specific': (final_prediction_pose == ground_truth),
        'all_scores': all_scores
    }

    results.append(result)
    pose_specific_results[estimated_pose].append(result)

print(f"\nProcessed {len(results)} queries")

# Generate results table (pose-aware predictions only), print it, save it as
# CSV and summarise accuracy overall and per true identity.
print("\n" + "="*80)
print("RESULTS TABLE - ALL QUERIES")
print("="*80)

results_table = [
    {
        'filename': r['query'],
        'true_id': r['ground_truth'],
        'pred_id': r['predicted_pose_specific'],
        'score': r['score_pose_specific'],
        'correct': 1 if r['correct_pose_specific'] else 0
    }
    for r in results
]

print(f"\n{'filename':<20} {'true_id':<10} {'pred_id':<10} {'score':<12} {'correct':<10}")
print("-"*80)

for row in results_table:
    print(f"{row['filename']:<20} {row['true_id']:<10} {row['pred_id']:<10} {row['score']:<12.5f} {row['correct']:<10}")

# Save to CSV; the file name records which metric produced the scores
csv_path = os.path.join(
    BASE_PATH,
    'results_table_' + ('L2' if USE_L2_DISTANCE else 'cosine') + '.csv'
)
fieldnames = ['filename', 'true_id', 'pred_id', 'score', 'correct']
with open(csv_path, 'w', newline='') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(results_table)

print(f"\n✓ Results table saved to: {csv_path}")

total_count = len(results_table)
correct_count = sum(row['correct'] for row in results_table)
accuracy = correct_count / total_count if total_count > 0 else 0

print(f"\nTable Summary:")
print(f"  Total Queries: {total_count}")
print(f"  Correct: {correct_count}")
print(f"  Incorrect: {total_count - correct_count}")
print(f"  Accuracy: {accuracy:.3f} ({100*accuracy:.1f}%)")

print(f"\nBreakdown by True ID:")
print(f"  {'ID':<10} {'Total':<10} {'Correct':<10} {'Accuracy':<10}")
print("-"*40)

true_ids = sorted({row['true_id'] for row in results_table})
for class_id in true_ids:
    class_rows = [row for row in results_table if row['true_id'] == class_id]
    class_total = len(class_rows)
    class_correct = sum(row['correct'] for row in class_rows)
    class_acc = class_correct / class_total if class_total > 0 else 0
    print(f"  {class_id:<10} {class_total:<10} {class_correct:<10} {class_acc:<10.3f}")

# Comparison section: split results by whether the true identity is in the
# gallery, then compare baseline and pose-aware accuracy on the known ones.
known_results, unknown_results = [], []
for entry in results:
    (known_results if entry['is_known'] else unknown_results).append(entry)
truly_unknown_results = unknown_results

print("\n" + "="*80)
print("BASELINE vs POSE-AWARE COMPARISON")
print("="*80)

known_results_temp = list(known_results)

total_known = len(known_results_temp)
baseline_correct = sum(bool(entry['correct_overall']) for entry in known_results_temp)
pose_correct = sum(bool(entry['correct_pose_specific']) for entry in known_results_temp)

print(f"\n{'Method':<20} {'Correct':<15} {'Accuracy':<15}")
print("-"*80)
for label, hits in (('BASELINE', baseline_correct), ('POSE-AWARE', pose_correct)):
    pct = 100 * hits / total_known if total_known > 0 else 0
    print(f"{label:<20} {hits}/{total_known:<14} {pct:>6.1f}%")
print("-"*80)

improvement = pose_correct - baseline_correct
improvement_pct = 100 * improvement / total_known if total_known > 0 else 0
count_sign = '+' if improvement >= 0 else ''
pct_sign = '+' if improvement_pct >= 0 else ''
print(f"{'IMPROVEMENT':<20} {count_sign}{improvement:<14} {pct_sign}{improvement_pct:>6.1f}%")

print("\n" + "="*80)
print("ANALYSIS COMPLETE")
print("="*80)
print(f"\nMetric Used: {'L2 Distance' if USE_L2_DISTANCE else 'Cosine Similarity'}")
print(f"Threshold: {SIMILARITY_THRESHOLD}")
print(f"{'Lower is better' if USE_L2_DISTANCE else 'Higher is better'}")
print(f"\nResults saved to: {csv_path}")

# Cleanup: closes the shared Face Mesh instance. The pose helpers use that
# instance, so presumably the setup cell must be re-run before this cell or
# any pose estimation is executed again.
face_mesh.close()