In [1]:
from tensorflow.keras.applications import efficientnet_v2
import tensorflow as tf
from tensorflow.keras.models import load_model
import numpy as np
from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt

### Processes an input image for a given model type (in our case contrastive or triplet) by resizing it and applying preprocessing


In [None]:
def process_image(img_source, target_size=(128, 128), preTrained=False):
    """Load, resize and normalise one image so it can be fed to the model.

    Args:
        img_source: Either a filesystem path (``str`` or ``pathlib.Path``) or
            an already opened image object exposing ``convert('RGB')``
            (e.g. a ``PIL.Image.Image``).
        target_size: ``(height, width)`` passed to ``tf.image.resize``.
        preTrained: When True the resized pixels go through
            ``efficientnet_v2.preprocess_input``; otherwise they are rescaled
            with ``x / 127.5 - 1.0``.

    Returns:
        The resized, preprocessed image as returned by TensorFlow, with three
        colour channels and no batch dimension.
    """
    if isinstance(img_source, (str, Path)):
        # Image.open is lazy and keeps the file handle open; convert() forces
        # the pixel data to be read and returns an independent image, so the
        # file can be closed right away.
        with Image.open(img_source) as opened:
            img = opened.convert('RGB')
    else:
        img = img_source.convert('RGB')

    # np.array always yields an ndarray, so the conversion is unconditional.
    pixels = tf.convert_to_tensor(np.array(img))
    resized = tf.image.resize(pixels, target_size)

    if preTrained:
        return efficientnet_v2.preprocess_input(resized)

    # 8-bit RGB values (0..255) are mapped onto the range [-1, 1].
    return (resized / 127.5) - 1.0

### Calculates the similarity between two images


In [None]:
def get_similarity(model, img1, img2, preTrained, model_type='contrastive', embedding_network_name='EmbeddingNetwork'):
    """Score how similar two face images are, higher meaning more alike.

    Args:
        model: The trained Siamese model. For ``model_type='contrastive'`` it
            is called with a pair of image batches; otherwise only its
            sub-layer named ``embedding_network_name`` is used.
        img1, img2: Image inputs accepted by ``process_image``.
        preTrained: Forwarded to ``process_image`` to select the preprocessing.
        model_type: ``'contrastive'`` selects the distance-based path. Any
            other value (e.g. ``'triplet'``) selects the embedding path.
        embedding_network_name: Name of the embedding sub-model, looked up
            with ``model.get_layer`` on the embedding path.

    Returns:
        Contrastive path: ``1 / (1 + distance)``, where ``distance`` is the
        first element of the model prediction (assumed to be a non-negative
        distance -- confirm against the model definition).
        Embedding path: cosine similarity of the two embeddings mapped from
        ``[-1, 1]`` onto ``[0, 1]``.
    """
    batch1 = np.expand_dims(process_image(img1, preTrained=preTrained), 0)
    batch2 = np.expand_dims(process_image(img2, preTrained=preTrained), 0)

    if model_type == 'contrastive':
        distance = model.predict([batch1, batch2], verbose=0)[0][0]
        return 1.0 / (1.0 + distance)

    embedding_network = model.get_layer(embedding_network_name)
    emb1 = np.ravel(embedding_network.predict(batch1, verbose=0)[0])
    emb2 = np.ravel(embedding_network.predict(batch2, verbose=0)[0])

    # A plain dot product only equals the cosine when both embeddings are
    # unit-length. Normalising explicitly keeps the score inside [0, 1] even
    # if the network has no L2-normalisation layer, and leaves the result
    # unchanged when it does.
    norm_product = np.linalg.norm(emb1) * np.linalg.norm(emb2)
    if norm_product == 0:
        # Cosine is undefined for a zero vector; fall back to the neutral
        # value the un-normalised dot product would have produced.
        cosine = 0.0
    else:
        cosine = np.clip(np.dot(emb1, emb2) / norm_product, -1.0, 1.0)
    return (cosine + 1) / 2

### Verifies if two face images match by computing their similarity and comparing it to a threshold

In [None]:
def verify_faces(model, img1_path, img2_path, preTrained, model_type='contrastive',
                 threshold=0.75, embedding_network_name='EmbeddingNetwork'):
    """Decide whether two face images match.

    The similarity from ``get_similarity`` is compared against ``threshold``;
    the pair counts as a match when the similarity is greater than or equal
    to it.

    Args:
        model: Trained Siamese model, forwarded to ``get_similarity``.
        img1_path, img2_path: The two images to compare; also echoed back in
            the result.
        preTrained: Forwarded to ``get_similarity``.
        model_type: Forwarded to ``get_similarity``.
        threshold: Minimum similarity for a match.
        embedding_network_name: Forwarded to ``get_similarity``.

    Returns:
        dict with keys ``'is_match'`` (bool), ``'similarity'`` (float),
        ``'image1'`` and ``'image2'``.
    """
    similarity = get_similarity(
        model, img1_path, img2_path,
        preTrained,
        model_type=model_type,
        embedding_network_name=embedding_network_name
    )

    return {
        # Comparing a NumPy scalar yields numpy.bool_; cast to a plain bool so
        # the result uses built-in types throughout, like 'similarity' does.
        'is_match': bool(similarity >= threshold),
        'similarity': float(similarity),
        'image1': img1_path,
        'image2': img2_path
    }

### Identifies matching faces by comparing a probe image against a gallery of images, returning the top matches above a similarity threshold


In [None]:
def identify_faces(model, probe_image, gallery_folder, preTrained, model_type='contrastive',
                   threshold=0.2, top_k=5, embedding_network_name='EmbeddingNetwork'):
    """Rank the images in a gallery folder by similarity to a probe image.

    Args:
        model: Trained Siamese model, forwarded to ``get_similarity``.
        probe_image: The query image; also echoed back in the result.
        gallery_folder: Directory whose direct children are searched. Only
            regular files ending in ``.jpg``, ``.jpeg`` or ``.webp``
            (case-insensitive) are considered; sub-folders are not descended.
        preTrained: Forwarded to ``get_similarity``.
        model_type: Forwarded to ``get_similarity``.
        threshold: Minimum similarity for a gallery image to be reported.
        top_k: Maximum number of best-scoring gallery images to consider.
        embedding_network_name: Forwarded to ``get_similarity``.

    Returns:
        dict with keys ``'probe_image'``, ``'matches'`` (list of
        ``{'gallery_image': str, 'similarity': float}`` ordered from most to
        least similar) and ``'match_found'`` (bool).
    """
    valid_extensions = {'.jpg', '.jpeg', '.webp'}
    # Sorted so that results (and tie-breaking) do not depend on the
    # arbitrary order in which the OS lists the directory. is_file() skips
    # directories that merely have an image-like suffix.
    gallery_images = sorted(
        str(entry) for entry in Path(gallery_folder).iterdir()
        if entry.is_file() and entry.suffix.lower() in valid_extensions
    )

    if not gallery_images:
        return {
            'probe_image': probe_image,
            'matches': [],
            'match_found': False
        }

    similarities = np.array([
        get_similarity(
            model, probe_image, gallery_img,
            preTrained,
            model_type=model_type,
            embedding_network_name=embedding_network_name
        )
        for gallery_img in gallery_images
    ])

    # Sort on the negated scores with a stable algorithm: highest similarity
    # first, ties keep path order, and NaN scores end up last instead of
    # occupying the top_k slots.
    ranked = np.argsort(-similarities, kind='stable')[:top_k]

    matches = [
        {
            'gallery_image': gallery_images[idx],
            'similarity': float(similarities[idx])
        }
        for idx in ranked
        if similarities[idx] >= threshold
    ]

    return {
        'probe_image': probe_image,
        'matches': matches,
        'match_found': len(matches) > 0
    }

### Loads an image from disk and converts it into a numpy array for display

In [25]:
def load_display_image(img_path):
    """Read an image from disk and return it as an RGB NumPy array for plotting.

    Args:
        img_path: Path of the image file, passed to ``PIL.Image.open``.

    Returns:
        ``numpy.ndarray`` holding the image converted to RGB.
    """
    # The context manager closes the file handle that Image.open keeps open.
    # Converting to RGB keeps grayscale / palette files from being drawn with
    # a false colormap by imshow and matches what process_image feeds the model.
    with Image.open(img_path) as img:
        return np.array(img.convert('RGB'))

### Displays face verification results

In [None]:
def visualize_verification_results(verification_results, save_path=None):
    """Draw the two compared photos above a similarity bar and the verdict.

    Args:
        verification_results: dict with the keys ``'image1'``, ``'image2'``,
            ``'similarity'`` and ``'is_match'`` (the structure returned by
            ``verify_faces``).
        save_path: If truthy, the figure is also written to this path
            (300 dpi, tight bounding box) before being shown.
    """
    plt.style.use('default')
    fig = plt.figure(figsize=(12, 6))

    # Top row: the two photos side by side. Bottom row: one wide score panel.
    grid = plt.GridSpec(2, 2, figure=fig)
    photo_axes = (fig.add_subplot(grid[0, 0]), fig.add_subplot(grid[0, 1]))
    score_ax = fig.add_subplot(grid[1, :])

    photos = [
        load_display_image(verification_results[key])
        for key in ('image1', 'image2')
    ]
    for axis, photo, title in zip(photo_axes, photos, ('Image 1', 'Image 2')):
        axis.imshow(photo)
        axis.set_title(title, pad=10)
        axis.axis('off')

    score = verification_results['similarity']
    # Grey background bar spanning the full 0..100 scale.
    score_ax.barh(y=0, width=100, height=0.3, color='lightgray')

    matched = verification_results['is_match']
    verdict_color = 'green' if matched else 'red'
    # Clamp the percentage so the coloured bar never leaves the background bar.
    filled_width = min(100, max(0, score * 100))
    score_ax.barh(y=0, width=filled_width, height=0.3, color=verdict_color)

    score_ax.text(50, 0.5, f"Similarity: {score:.3f}",
                  ha='center', va='bottom', fontsize=12)
    verdict = "MATCH" if matched else "NO MATCH"
    score_ax.text(50, -0.5, verdict, ha='center', va='top',
                  fontsize=14, fontweight='bold', color=verdict_color)

    score_ax.set_xlim(-5, 105)
    score_ax.set_ylim(-1, 1)
    score_ax.axis('off')

    fig.suptitle("Face Verification Results", fontsize=14, y=0.95)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

### Visualizes face identification results

In [4]:
def visualize_identification_results(identification_results, save_path=None,
                                     strong_match_threshold=0.7):
    """Show the probe image next to its gallery matches in a grid.

    Args:
        identification_results: dict with the keys ``'probe_image'`` and
            ``'matches'`` (a list of dicts holding ``'gallery_image'`` and
            ``'similarity'``), i.e. the structure returned by
            ``identify_faces``.
        save_path: If truthy, the figure is also written to this path
            (300 dpi, tight bounding box) before being shown.
        strong_match_threshold: Matches whose similarity is strictly above
            this value get a green title, the others an orange one.
    """
    plt.style.use('default')
    matches = identification_results['matches']

    if not matches:
        plt.figure(figsize=(8, 4))
        plt.text(0.5, 0.5, "NO MATCHES FOUND",
                 ha='center', va='center', fontsize=20, color='red')
        plt.axis('off')
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
        return

    # One panel for the probe plus one per match, at most three per row.
    n_panels = len(matches) + 1
    n_cols = min(3, n_panels)
    # Ceiling division: the grid must have at least n_panels slots, otherwise
    # adding the last subplot raises ValueError (e.g. with exactly 3 matches).
    n_rows = (n_panels + n_cols - 1) // n_cols
    fig = plt.figure(figsize=(4*n_cols, 4*n_rows))

    ax_probe = fig.add_subplot(n_rows, n_cols, 1)
    ax_probe.imshow(load_display_image(identification_results['probe_image']))
    ax_probe.set_title('Probe Image', fontsize=12, pad=10)
    ax_probe.axis('off')

    # Subplot slots are 1-based and slot 1 holds the probe, so matches start at 2.
    for slot, match in enumerate(matches, 2):
        ax = fig.add_subplot(n_rows, n_cols, slot)
        ax.imshow(load_display_image(match['gallery_image']))
        score_color = 'green' if match['similarity'] > strong_match_threshold else 'orange'
        ax.set_title(f"Match {slot-1}\nSimilarity: {match['similarity']:.3f}",
                     fontsize=12, pad=10, color=score_color)
        ax.axis('off')

    fig.suptitle("Face Identification Results", fontsize=16, y=1.02)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
# Load the saved contrastive model (v1). compile=False is passed to
# load_model, and compile() is then called with its default arguments.
CONTRASTIVE_V1_PATH = '../results/siamese/contrastive_v1/contrastive_v1.h5'
v1 = load_model(CONTRASTIVE_V1_PATH, compile=False)
v1.compile()

### Verification example

In [None]:
# Compare two photos and decide whether they show the same person.
first_photo = 'path/to/person1_photo1.jpg'
second_photo = 'path/to/person1_photo2.jpg'

verification_results = verify_faces(
    v1,
    first_photo,
    second_photo,
    preTrained=False,          # set to True if using the EfficientNet model
    model_type='contrastive',  # or 'triplet'
    threshold=0.75,            # minimum similarity counted as a match; adjust as needed
)


In [None]:
# Plot both photos together with the similarity bar and the match verdict
visualize_verification_results(verification_results, save_path=None)

### Identification example

In [None]:
# Look for the probe face among the images of a gallery folder.
probe_photo = 'path/to/probe_image.jpg'
gallery_dir = 'path/to/gallery_folder'

identification_results = identify_faces(
    v1,
    probe_photo,
    gallery_dir,
    preTrained=False,          # set to True if using the EfficientNet model
    model_type='contrastive',  # or 'triplet'
    threshold=0.75,            # minimum similarity for a gallery image to be reported
    top_k=5,                   # report at most this many best matches
)

In [None]:
# Plot the probe image next to the gallery matches that were found
visualize_identification_results(identification_results, save_path=None)