# Task 4: Matching Pipeline
## Implement cosine similarity matching and evaluate identification accuracy

### 1. Import Libraries

In [None]:
import cv2
import numpy as np
from pathlib import Path
import sys
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd

# Put the parent of the current working directory at the front of sys.path
# so that the project-local `src` package imported below can be resolved.
_project_root = Path('..').resolve()
sys.path.insert(0, str(_project_root))
from src.embedding import FaceEmbedder, EmbeddingDatabase
from src.matching import FaceRecognitionMatcher, MatchingEvaluator

print("Libraries imported successfully!")

### 2. Setup Paths and Load Data

In [None]:
# All locations are derived from the parent of the current working directory.
BASE_PATH = Path('..').resolve()
VALIDATION_PATH = BASE_PATH.joinpath('data', 'validation')
DB_PATH = BASE_PATH.joinpath('data', 'embeddings.db')

# Echo the resolved locations so a wrong working directory is easy to spot.
for label, location in (
    ("Base path", BASE_PATH),
    ("Validation path", VALIDATION_PATH),
    ("Database path", DB_PATH),
):
    print(f"{label}: {location}")
print(f"Database exists: {DB_PATH.exists()}")

# Open the embedding database and report its summary counters.
db = EmbeddingDatabase(str(DB_PATH))
stats = db.get_db_stats()

print("\nDatabase loaded:")
print(f"  Identities: {stats['total_identities']}")
print(f"  Embeddings: {stats['total_embeddings']}")

# Collect validation images laid out as <identity>/<file>; only lowercase
# .jpg and .png files are matched. The result is sorted by path.
val_images = sorted(
    image_path
    for pattern in ('*/*.jpg', '*/*.png')
    for image_path in VALIDATION_PATH.glob(pattern)
)
print(f"\nValidation images found: {len(val_images)}")
if val_images:
    print(f"Sample: {val_images[0]}")

### 3. Load Gallery Embeddings

In [None]:
# Fetch every stored gallery embedding; the result is used as a mapping from
# identity to a sized, array-like collection of embeddings.
all_embeddings = db.get_all_embeddings()

print("Gallery embeddings loaded:")
print(f"  Identities: {len(all_embeddings)}")

total_embeddings = sum(map(len, all_embeddings.values()))
print(f"  Total embeddings: {total_embeddings}")

# Preview the first three identities as a quick sanity check.
preview_items = list(all_embeddings.items())[:3]
for identity, embeddings in preview_items:
    count = len(embeddings)
    shape = embeddings.shape
    print(f"  {identity}: {count} embeddings, shape {shape}")

### 4. Initialize Matcher

In [None]:
# Matching configuration, passed straight through to FaceRecognitionMatcher.
# NOTE(review): presumably `threshold` is the minimum similarity for a match
# to be accepted and `top_k` the number of candidates returned per query --
# confirm against src.matching.
MATCH_THRESHOLD = 0.6
TOP_K = 5

matcher = FaceRecognitionMatcher(
    embeddings_dict=all_embeddings,
    threshold=MATCH_THRESHOLD,
    top_k=TOP_K,
)

# Report the configuration as the matcher itself exposes it.
print("Matcher initialized!")
print(
    f"  Gallery size: {matcher.num_gallery} embeddings\n"
    f"  Threshold: {matcher.threshold}\n"
    f"  Top-K: {matcher.top_k}"
)

### 5. Extract Validation Embeddings

In [None]:
# Initialize embedder for validation images
embedder = FaceEmbedder(model_name='vggface2')

# Parallel lists: val_embeddings[i] belongs to identity val_ground_truth[i].
val_embeddings = []
val_ground_truth = []
# Paths that could not be processed, plus the reason for each so that
# failures can be diagnosed instead of being silently dropped.
failed_images = []
failure_reasons = {}

print("Extracting validation embeddings...\n")

for val_img_path in tqdm(val_images, desc="Processing validation images"):
    img_path_str = str(val_img_path)
    try:
        # cv2.imread signals an unreadable file by returning None, not by raising.
        # NOTE(review): cv2.imread yields BGR channel order -- confirm that
        # FaceEmbedder.extract_embedding expects BGR input.
        img = cv2.imread(img_path_str)
        if img is None:
            failed_images.append(img_path_str)
            failure_reasons[img_path_str] = "image could not be read"
            continue

        embedding = embedder.extract_embedding(img)
    except Exception as e:
        # Best-effort: one bad image must not abort the run, but keep the cause.
        failed_images.append(img_path_str)
        failure_reasons[img_path_str] = f"{type(e).__name__}: {e}"
        continue

    # NOTE(review): guards against an embedder that reports "no face" by
    # returning None; a None entry would corrupt the stacked array below.
    if embedding is None:
        failed_images.append(img_path_str)
        failure_reasons[img_path_str] = "embedder returned no embedding"
        continue

    # The ground-truth identity is the name of the image's parent directory.
    val_embeddings.append(embedding)
    val_ground_truth.append(val_img_path.parent.name)

val_embeddings = np.array(val_embeddings)

print("\n✓ Extraction complete!")
print(f"  Processed: {len(val_embeddings)} images")
print(f"  Failed: {len(failed_images)}")
print(f"  Embedding shape: {val_embeddings.shape}")

if failure_reasons:
    print("  Failure reasons:")
    for failed_path, reason in failure_reasons.items():
        print(f"    {failed_path}: {reason}")

### 6. Match Validation Images

In [None]:
# Match all validation embeddings against the gallery in one batch call.
print("Matching validation images...\n")

matches = matcher.match_batch(val_embeddings)

print("✓ Matching complete!")

# Each match is used as a dict with 'identity' and 'confidence' entries;
# pull those out as flat lists for the later analysis cells.
predictions = [m['identity'] for m in matches]
confidences = [m['confidence'] for m in matches]

# Show the first few results next to their ground-truth labels.
print("\nSample matching results (first 5):")
for match, gt in zip(matches[:5], val_ground_truth):
    correct = "✓" if match['identity'] == gt else "✗"
    print(f"{correct} Predicted: {match['identity']}, Ground Truth: {gt}, Confidence: {match['confidence']:.4f}")

### 7. Calculate Identification Accuracy

In [None]:
# Overall identification accuracy over the whole validation set. The result
# is read as a dict holding top-1 / top-5 rates, hit counts and sample total.
accuracy = matcher.get_identification_accuracy(val_embeddings, val_ground_truth)

rule = "=" * 60
print(rule)
print("IDENTIFICATION ACCURACY")
print(rule)

# One (heading, rate key, count key) triple per reported rank; the second
# heading carries the blank line that separates the two sections.
for heading, rate_key, count_key in (
    ("Top-1 Accuracy", 'top_1_accuracy', 'top_1_correct'),
    ("\nTop-5 Accuracy", 'top_5_accuracy', 'top_5_correct'),
):
    rate = accuracy[rate_key]
    print(f"{heading}: {rate:.4f} ({rate*100:.2f}%)")
    print(f"  Correct: {accuracy[count_key]}/{accuracy['total_samples']}")

print(rule)

### 8. Per-Identity Accuracy

In [None]:
# Compute per-identity accuracy; the result is read as a mapping from
# identity to a dict of counts and rates.
per_identity = matcher.get_per_identity_accuracy(val_embeddings, val_ground_truth)

# Fixed column order. Passing it to the DataFrame also guarantees the columns
# exist when there are no rows, so the lookups below cannot raise KeyError.
per_identity_columns = [
    'Identity',
    'Total',
    'Top-1 Correct',
    'Top-1 Accuracy',
    'Top-5 Correct',
    'Top-5 Accuracy',
]

# Built with a comprehension so the loop names stay local and do not
# overwrite the notebook-level `stats` (database statistics) variable.
identity_data = [
    {
        'Identity': identity_name,
        'Total': identity_stats['total'],
        'Top-1 Correct': identity_stats['top_1_correct'],
        'Top-1 Accuracy': identity_stats['top_1_accuracy'],
        'Top-5 Correct': identity_stats['top_5_correct'],
        'Top-5 Accuracy': identity_stats['top_5_accuracy'],
    }
    for identity_name, identity_stats in sorted(per_identity.items())
]

df_per_identity = pd.DataFrame(identity_data, columns=per_identity_columns)

print("Per-Identity Results:")
print(df_per_identity.to_string(index=False))

print("\nStatistics:")
if df_per_identity.empty:
    # Nothing to aggregate; say so instead of printing NaN statistics.
    print("  No per-identity results available.")
else:
    top1_accuracy = df_per_identity['Top-1 Accuracy']
    print(f"  Mean Top-1 Accuracy: {top1_accuracy.mean():.4f}")
    print(f"  Min Top-1 Accuracy: {top1_accuracy.min():.4f}")
    print(f"  Max Top-1 Accuracy: {top1_accuracy.max():.4f}")

### 9. Threshold Analysis

In [None]:
# Sweep 15 evenly spaced candidate thresholds from 0.3 to 1.0 (inclusive)
# and let the evaluator score the existing predictions at each one.
threshold_grid = np.linspace(0.3, 1.0, 15)
threshold_results = MatchingEvaluator.threshold_analysis(
    confidences=confidences,
    ground_truth=val_ground_truth,
    predictions=predictions,
    thresholds=threshold_grid,
)

df_thresholds = pd.DataFrame(threshold_results)

print("Threshold Analysis:")
print(df_thresholds.to_string(index=False))

# The "optimal" row is simply the one with the highest accuracy (the first
# such row on ties).
# NOTE(review): coverage is reported but plays no part in the selection, so a
# threshold with very low coverage can win -- confirm this is intended.
optimal_idx = df_thresholds['accuracy'].idxmax()
optimal_threshold, optimal_accuracy, optimal_coverage = (
    df_thresholds.at[optimal_idx, column]
    for column in ('threshold', 'accuracy', 'coverage')
)

print(f"\nOptimal Threshold: {optimal_threshold:.2f}")
print(f"  Accuracy: {optimal_accuracy:.4f}")
print(f"  Coverage: {optimal_coverage:.4f}")

### 10. Visualization

In [None]:
# Create a 2x2 dashboard: per-identity accuracy, confidence histogram,
# threshold sweep, and a text panel summarising the run.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
ax_identity, ax_confidence = axes[0, 0], axes[0, 1]
ax_threshold, ax_summary = axes[1, 0], axes[1, 1]

# 1. Per-identity top-1 accuracy
ax_identity.barh(df_per_identity['Identity'], df_per_identity['Top-1 Accuracy'])
ax_identity.set_xlabel('Top-1 Accuracy')
ax_identity.set_title('Per-Identity Top-1 Accuracy')
# Slight headroom past 1.0 so full-length bars do not touch the frame.
ax_identity.set_xlim([0, 1.1])
ax_identity.grid(alpha=0.3, axis='x')

# 2. Confidence distribution, with the matcher's current threshold marked
ax_confidence.hist(confidences, bins=15, edgecolor='black')
ax_confidence.axvline(matcher.threshold, color='r', linestyle='--', linewidth=2, label=f'Current: {matcher.threshold}')
ax_confidence.set_xlabel('Confidence Score')
ax_confidence.set_ylabel('Frequency')
ax_confidence.set_title('Confidence Distribution')
ax_confidence.legend()
ax_confidence.grid(alpha=0.3)

# 3. Threshold vs Accuracy/Coverage, with the selected threshold marked
ax_threshold.plot(df_thresholds['threshold'], df_thresholds['accuracy'], 'b-o', label='Accuracy', linewidth=2)
ax_threshold.plot(df_thresholds['threshold'], df_thresholds['coverage'], 'r-s', label='Coverage', linewidth=2)
ax_threshold.axvline(optimal_threshold, color='g', linestyle='--', linewidth=2, label=f'Optimal: {optimal_threshold:.2f}')
ax_threshold.set_xlabel('Threshold')
ax_threshold.set_ylabel('Rate')
ax_threshold.set_title('Threshold Analysis')
ax_threshold.legend()
ax_threshold.grid(alpha=0.3)

# 4. Accuracy summary rendered as monospace text in an otherwise empty panel
ax_summary.axis('off')
summary_text = f"""
MATCHING SUMMARY
{'='*40}

Overall Performance:
  Top-1 Accuracy: {accuracy['top_1_accuracy']*100:.2f}%
  Top-5 Accuracy: {accuracy['top_5_accuracy']*100:.2f}%

Threshold: {matcher.threshold}
Top-K: {matcher.top_k}
Gallery Size: {matcher.num_gallery} embeddings
Test Size: {len(val_embeddings)} images

Confidence Stats:
  Min: {np.min(confidences):.4f}
  Max: {np.max(confidences):.4f}
  Mean: {np.mean(confidences):.4f}
  Std: {np.std(confidences):.4f}

Optimal Threshold: {optimal_threshold:.2f}
  Accuracy: {optimal_accuracy*100:.2f}%
  Coverage: {optimal_coverage*100:.2f}%
"""
ax_summary.text(0.05, 0.95, summary_text, fontsize=11, family='monospace',
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

plt.tight_layout()
plt.show()

print("✓ Visualization complete")

### 11. Detailed Matching Examples

In [None]:
# Show detailed matches for all validation images
rule = "=" * 80
print("\n" + rule)
print("DETAILED MATCHING RESULTS")
print(rule)

correct_count = 0
for i, (match, gt) in enumerate(zip(matches, val_ground_truth), 1):
    # A result counts as correct when the top-1 identity equals the label.
    is_correct = match['identity'] == gt
    if is_correct:
        correct_count += 1

    status = "✓ CORRECT" if is_correct else "✗ WRONG"
    print(f"\nImage {i}: {status}")
    print(f"  Ground Truth: {gt}")
    print(f"  Top-1 Match: {match['identity']} (confidence: {match['confidence']:.4f})")
    print(f"  Matched: {'Yes' if match['matched'] else 'No'}")
    print(f"  Inference time: {match['inference_time_ms']:.2f} ms")

    # List at most five ranked candidates, numbered from 1.
    print("  Top-5 Matches:")
    for j, top_match in enumerate(match['top_k_matches'][:5], 1):
        print(f"    {j}. {top_match['identity']}: {top_match['confidence']:.4f}")

# Guard the percentage so an empty validation set does not raise
# ZeroDivisionError; it is reported as 0.00% instead.
total_samples = len(val_ground_truth)
correct_pct = correct_count / total_samples * 100 if total_samples else 0.0

print("\n" + rule)
print(f"Summary: {correct_count}/{total_samples} correct ({correct_pct:.2f}%)")
print(rule)

### 12. Summary

In [None]:
# Final report: matcher setup, headline accuracies, configuration and the
# threshold selected by the threshold-analysis cell.
rule = "=" * 60
print("\n" + rule)
print("TASK 4: MATCHING PIPELINE - SUMMARY")
print(rule)
print("\n✓ Matcher: Cosine Similarity")
print(f"✓ Gallery: {matcher.num_gallery} embeddings from {len(all_embeddings)} identities")
print("\n📊 RESULTS:")
print(f"   Top-1 Accuracy: {accuracy['top_1_accuracy']:.4f} ({accuracy['top_1_accuracy']*100:.2f}%)")
print(f"   Top-5 Accuracy: {accuracy['top_5_accuracy']:.4f} ({accuracy['top_5_accuracy']*100:.2f}%)")
print("\n⚙️ CONFIGURATION:")
print(f"   Threshold: {matcher.threshold}")
print(f"   Top-K: {matcher.top_k}")
print("\n💡 RECOMMENDATION:")
print(f"   Optimal Threshold: {optimal_threshold:.2f} (accuracy: {optimal_accuracy*100:.2f}%)")
print("\n" + rule)