In [None]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from typing import Tuple, List, Optional

class InceptionBlock(nn.Module):
    """Implementation of Inception block used in FaceNet"""
    
    def __init__(
        self,
        in_channels: int,
        ch1x1: int,
        ch3x3red: int,
        ch3x3: int,
        ch5x5red: int,
        ch5x5: int,
        pool_proj: int
    ):
        super().__init__()
        
        # 1x1 conv branch
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, ch1x1, kernel_size=1),
            nn.BatchNorm2d(ch1x1),
            nn.ReLU()
        )
        
        # 1x1 conv -> 3x3 conv branch
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, ch3x3red, kernel_size=1),
            nn.BatchNorm2d(ch3x3red),
            nn.ReLU(),
            nn.Conv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch3x3),
            nn.ReLU()
        )
        
        # 1x1 conv -> 5x5 conv branch (implemented as two 3x3 convs)
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, ch5x5red, kernel_size=1),
            nn.BatchNorm2d(ch5x5red),
            nn.ReLU(),
            nn.Conv2d(ch5x5red, ch5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch5x5),
            nn.ReLU(),
            nn.Conv2d(ch5x5, ch5x5, kernel_size=3, padding=1),
            nn.BatchNorm2d(ch5x5),
            nn.ReLU()
        )
        
        # Max pool -> 1x1 conv branch
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, pool_proj, kernel_size=1),
            nn.BatchNorm2d(pool_proj),
            nn.ReLU()
        )
    
    def __call__(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)
        
        # Concatenate along channel dimension
        return mx.concatenate([branch1, branch2, branch3, branch4], axis=1)

class FaceNet(nn.Module):
    """FaceNet implementation in MLX"""
    
    def __init__(self, embedding_size: int = 128):
        super().__init__()
        
        # Initial convolutional layers
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        # Inception blocks
        self.inception1 = InceptionBlock(64, 64, 96, 128, 16, 32, 32)
        self.inception2 = InceptionBlock(256, 128, 128, 192, 32, 96, 64)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.inception3a = InceptionBlock(480, 192, 96, 208, 16, 48, 64)
        self.inception3b = InceptionBlock(512, 160, 112, 224, 24, 64, 64)
        self.inception3c = InceptionBlock(512, 128, 128, 256, 24, 64, 64)
        self.inception4a = InceptionBlock(512, 112, 144, 288, 32, 64, 64)
        self.inception4b = InceptionBlock(528, 256, 160, 320, 32, 128, 128)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.inception4c = InceptionBlock(832, 256, 160, 320, 32, 128, 128)
        self.inception4d = InceptionBlock(832, 384, 192, 384, 48, 128, 128)
        
        # Final layers
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, embedding_size)
        self.bn2 = nn.BatchNorm1d(embedding_size)
        
        # L2 normalization
        self.l2_norm = lambda x: x / mx.sqrt(mx.sum(mx.square(x), axis=1, keepdims=True))
    
    def __call__(self, x, training: bool = False):
        # Initial convolutional layers
        x = self.maxpool1(self.relu(self.bn1(self.conv1(x))))
        
        # Inception blocks
        x = self.inception1(x)
        x = self.inception2(x)
        x = self.maxpool2(x)
        
        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.inception3c(x)
        x = self.inception4a(x)
        x = self.inception4b(x)
        x = self.maxpool3(x)
        
        x = self.inception4c(x)
        x = self.inception4d(x)
        
        # Final layers
        x = self.avgpool(x)
        x = mx.reshape(x, (-1, 1024))
        if training:
            x = self.dropout(x)
        x = self.fc(x)
        x = self.bn2(x)
        
        # L2 normalization to create the embedding
        embedding = self.l2_norm(x)
        return embedding

class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining"""
    
    def __init__(self, margin: float = 0.2):
        super().__init__()
        self.margin = margin
    
    def __call__(self, embeddings, labels):
        # Get the pairwise distance matrix
        pairwise_dist = self._pairwise_distances(embeddings)
        
        # For each anchor, get the hardest positive
        # First, get a mask for valid positive pairs (same class)
        labels = mx.array(labels)
        mask_positives = mx.equal(mx.expand_dims(labels, axis=0), mx.expand_dims(labels, axis=1))
        
        # Exclude the diagonal from mask_positives (distance to self is 0, not a useful positive)
        mask_positives = mask_positives & (1 - mx.eye(labels.shape[0], dtype=mx.bool_))
        
        # Get hardest positives (maximum distance to positive samples)
        hardest_positive_dist = mx.max(pairwise_dist * mask_positives.astype(mx.float32), axis=1)
        
        # For each anchor, get the hardest negative
        # First, get a mask for valid negative pairs (different class)
        mask_negatives = ~mask_positives
        
        # Make invalid negatives have large distance so they won't be selected
        max_dist = mx.max(pairwise_dist)
        neg_dist = pairwise_dist * mask_negatives.astype(mx.float32) + max_dist * (~mask_negatives).astype(mx.float32)
        
        # Get hardest negatives (minimum distance to negative samples)
        hardest_negative_dist = mx.min(neg_dist, axis=1)
        
        # Calculate triplet loss
        triplet_loss = mx.maximum(hardest_positive_dist - hardest_negative_dist + self.margin, 0.0)
        
        # Get final mean triplet loss
        return mx.mean(triplet_loss)
    
    def _pairwise_distances(self, embeddings):
        """Compute the 2D matrix of distances between all embeddings."""
        # Get dot product (batch_size, batch_size)
        dot_product = mx.matmul(embeddings, mx.transpose(embeddings))
        
        # Get squared L2 norm for each embedding (batch_size, 1)
        square_norm = mx.sum(mx.square(embeddings), axis=1)
        
        # Calculate pairwise distance matrix 
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.dot(b)
        distances = mx.expand_dims(square_norm, 1) + mx.expand_dims(square_norm, 0) - 2.0 * dot_product
        
        # Because of computation errors, some distances might be negative
        distances = mx.maximum(distances, 0.0)
        
        # If we want the actual distance, we can use sqrt
        return mx.sqrt(distances)

# Data preprocessing function
def preprocess_image(image_path: str, target_size: Tuple[int, int] = (160, 160)) -> mx.array:
    """
    Load and preprocess an image for FaceNet.
    
    Args:
        image_path: Path to the image file
        target_size: Target size for the image (height, width)
        
    Returns:
        Preprocessed image as MLX array
    """
    try:
        import cv2
        # Load image
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Unable to load image from {image_path}")
        
        # Convert BGR to RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # Resize to target size
        img = cv2.resize(img, target_size[::-1])  # CV2 takes (width, height)
        
        # Normalize pixel values to [-1, 1]
        img = img.astype(np.float32) / 127.5 - 1.0
        
        # Convert to MLX array and add batch dimension
        return mx.array(img).transpose(2, 0, 1)  # Convert to channels-first format
    
    except Exception as e:
        print(f"Error preprocessing image {image_path}: {e}")
        return None

# Example of loading pretrained weights from a PyTorch model
def load_pretrained_weights(facenet_mlx: FaceNet, torch_model_path: str) -> FaceNet:
    """
    Load pretrained weights from a PyTorch FaceNet model.
    
    Args:
        facenet_mlx: MLX FaceNet model
        torch_model_path: Path to PyTorch model file
        
    Returns:
        MLX FaceNet model with loaded weights
    """
    try:
        import torch
        
        # Load PyTorch model
        torch_model = torch.load(torch_model_path, map_location="cpu")
        
        # Create a state dict for MLX model
        state_dict = {}
        
        # Map PyTorch parameter names to MLX parameter names
        # This is a simplified example - you would need to adapt based on your specific models
        for name, param in torch_model.items():
            # Convert PyTorch tensor to numpy array
            param_np = param.numpy()
            
            # Handle convolution weight format differences (PyTorch uses OIHW, MLX uses OIHW)
            if len(param_np.shape) == 4:
                # No need to transpose for this case, format is the same
                pass
            
            # Add to MLX state dict - you might need to adjust the naming convention here
            # based on your specific PyTorch model
            state_dict[name] = mx.array(param_np)
        
        # Load state dict into MLX model
        facenet_mlx.load_weights(state_dict)
        
        print("Loaded pretrained weights from PyTorch model")
        return facenet_mlx
    
    except Exception as e:
        print(f"Error loading pretrained weights: {e}")
        return facenet_mlx

# Example usage for face recognition
def recognize_faces(
    model: FaceNet,
    unknown_face_path: str,
    reference_faces: List[Tuple[str, str]],
    threshold: float = 0.7
) -> List[Tuple[str, float]]:
    """
    Recognize faces by comparing an unknown face against reference faces.
    
    Args:
        model: Trained FaceNet model
        unknown_face_path: Path to unknown face image
        reference_faces: List of (person_name, face_image_path) tuples
        threshold: Similarity threshold (cosine similarity)
        
    Returns:
        List of (person_name, similarity) tuples for matches above threshold
    """
    # Preprocess unknown face
    unknown_face = preprocess_image(unknown_face_path)
    if unknown_face is None:
        return []
    
    # Get embedding for unknown face
    unknown_embedding = model(mx.expand_dims(unknown_face, 0))[0]
    
    matches = []
    
    # Compare with reference faces
    for name, face_path in reference_faces:
        # Preprocess reference face
        ref_face = preprocess_image(face_path)
        if ref_face is None:
            continue
        
        # Get embedding for reference face
        ref_embedding = model(mx.expand_dims(ref_face, 0))[0]
        
        # Calculate cosine similarity
        similarity = mx.sum(unknown_embedding * ref_embedding) / (
            mx.sqrt(mx.sum(unknown_embedding ** 2)) * mx.sqrt(mx.sum(ref_embedding ** 2))
        )
        
        # Check if similarity is above threshold
        similarity_value = similarity.item()
        if similarity_value > threshold:
            matches.append((name, similarity_value))
    
    # Sort matches by similarity (highest first)
    return sorted(matches, key=lambda x: x[1], reverse=True)

# Training function
def train_facenet(
    model: FaceNet,
    train_data_path: str,
    val_data_path: Optional[str] = None,
    epochs: int = 10,
    batch_size: int = 32,
    learning_rate: float = 0.001,
    margin: float = 0.2
):
    """
    Train FaceNet model using triplet loss.
    
    Args:
        model: FaceNet model
        train_data_path: Path to training data directory
        val_data_path: Path to validation data directory (optional)
        epochs: Number of training epochs
        batch_size: Batch size
        learning_rate: Learning rate
        margin: Margin for triplet loss
        
    Returns:
        Trained model
    """
    # Set up optimizer
    optimizer = optim.Adam(learning_rate=learning_rate)
    
    # Set up loss function
    triplet_loss = TripletLoss(margin=margin)
    
    # Define training step
    @mx.compile
    def train_step(model, images, labels):
        def loss_fn(model):
            embeddings = model(images, training=True)
            loss = triplet_loss(embeddings, labels)
            return loss
        
        loss, grads = nn.value_and_grad(loss_fn)(model)
        optimizer.update(model, grads)
        return loss
    
    # TODO: Implement data loading and training loop
    # This would require implementing a DataLoader for MLX
    # For brevity, only the core training step is shown
    
    print("Training function defined - actual implementation would require a DataLoader")
    return model

# Main execution example
def main():
    # Initialize FaceNet model
    facenet = FaceNet(embedding_size=128)
    
    # If you have pretrained weights
    # facenet = load_pretrained_weights(facenet, "path_to_pytorch_model.pth")
    
    # Example: Face recognition
    # matches = recognize_faces(
    #     model=facenet,
    #     unknown_face_path="unknown_face.jpg",
    #     reference_faces=[
    #         ("Person1", "person1_face.jpg"),
    #         ("Person2", "person2_face.jpg"),
    #     ],
    #     threshold=0.7
    # )
    
    # Print matches
    # for name, similarity in matches:
    #     print(f"Match: {name}, Similarity: {similarity:.4f}")
    
    print("FaceNet model successfully initialized")
    return facenet

if __name__ == "__main__":
    main()

In [3]:
# Install required packages
!pip install mlx opencv-python numpy matplotlib

Collecting matplotlib
  Using cached matplotlib-3.10.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (5.4 kB)
Collecting cycler>=0.10 (from matplotlib)
  Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.56.0-cp311-cp311-macosx_10_9_universal2.whl.metadata (101 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib)
  Downloading kiwisolver-1.4.8-cp311-cp311-macosx_11_0_arm64.whl.metadata (6.2 kB)
Collecting pillow>=8 (from matplotlib)
  Using cached pillow-11.1.0-cp311-cp311-macosx_11_0_arm64.whl.metadata (9.1 kB)
Collecting pyparsing>=2.3.1 (from matplotlib)
  Downloading pyparsing-3.2.1-py3-none-any.whl.metadata (5.0 kB)
Downloading matplotlib-3.10.1-cp311-cp311-macosx_11_0_arm64.whl (8.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.0/8.0 MB[0m [31m63.4 MB/s

In [14]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from typing import Tuple, List, Optional
import matplotlib.pyplot as plt
import os

In [15]:
class InceptionBlock(nn.Module):
    """Implementation of Inception block used in FaceNet"""
    
    def __init__(
        self,
        in_channels: int,
        ch1x1: int,
        ch3x3red: int,
        ch3x3: int,
        ch5x5red: int,
        ch5x5: int,
        pool_proj: int
    ):
        super().__init__()
        
        # 1x1 conv branch
        self.branch1 = nn.Sequential(
            nn.Conv2d(in_channels, ch1x1, kernel_size=1),
            nn.BatchNorm(ch1x1),
            nn.ReLU()
        )
        
        # 1x1 conv -> 3x3 conv branch
        self.branch2 = nn.Sequential(
            nn.Conv2d(in_channels, ch3x3red, kernel_size=1),
            nn.BatchNorm(ch3x3red),
            nn.ReLU(),
            nn.Conv2d(ch3x3red, ch3x3, kernel_size=3, padding=1),
            nn.BatchNorm(ch3x3),
            nn.ReLU()
        )
        
        # 1x1 conv -> 5x5 conv branch (implemented as two 3x3 convs)
        self.branch3 = nn.Sequential(
            nn.Conv2d(in_channels, ch5x5red, kernel_size=1),
            nn.BatchNorm(ch5x5red),
            nn.ReLU(),
            nn.Conv2d(ch5x5red, ch5x5, kernel_size=3, padding=1),
            nn.BatchNorm(ch5x5),
            nn.ReLU(),
            nn.Conv2d(ch5x5, ch5x5, kernel_size=3, padding=1),
            nn.BatchNorm(ch5x5),
            nn.ReLU()
        )
        
        # Max pool -> 1x1 conv branch
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            nn.Conv2d(in_channels, pool_proj, kernel_size=1),
            nn.BatchNorm(pool_proj),
            nn.ReLU()
        )
    
    def __call__(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        branch3 = self.branch3(x)
        branch4 = self.branch4(x)
        
        # Concatenate along channel dimension
        return mx.concatenate([branch1, branch2, branch3, branch4], axis=1)

In [16]:
class FaceNet(nn.Module):
    """FaceNet implementation in MLX"""
    
    def __init__(self, embedding_size: int = 128):
        super().__init__()
        
        # Initial convolutional layers
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.bn1 = nn.BatchNorm(64)  # Fixed: Changed from BatchNorm2d to BatchNorm
        self.relu = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        # Inception blocks
        self.inception1 = InceptionBlock(64, 64, 96, 128, 16, 32, 32)
        self.inception2 = InceptionBlock(256, 128, 128, 192, 32, 96, 64)
        self.maxpool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.inception3a = InceptionBlock(480, 192, 96, 208, 16, 48, 64)
        self.inception3b = InceptionBlock(512, 160, 112, 224, 24, 64, 64)
        self.inception3c = InceptionBlock(512, 128, 128, 256, 24, 64, 64)
        self.inception4a = InceptionBlock(512, 112, 144, 288, 32, 64, 64)
        self.inception4b = InceptionBlock(528, 256, 160, 320, 32, 128, 128)
        self.maxpool3 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        self.inception4c = InceptionBlock(832, 256, 160, 320, 32, 128, 128)
        self.inception4d = InceptionBlock(832, 384, 192, 384, 48, 128, 128)
        
        # Final layers
        self.avgpool = nn.AvgPool2d(kernel_size=7, stride=1)
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(1024, embedding_size)
        self.bn2 = nn.BatchNorm(embedding_size)  # Fixed: Changed from BatchNorm1d to BatchNorm
        
        # L2 normalization
        self.l2_norm = lambda x: x / mx.sqrt(mx.sum(mx.square(x), axis=1, keepdims=True))
    
    def __call__(self, x, training: bool = False):
        # Initial convolutional layers
        x = self.maxpool1(self.relu(self.bn1(self.conv1(x))))
        
        # Inception blocks
        x = self.inception1(x)
        x = self.inception2(x)
        x = self.maxpool2(x)
        
        x = self.inception3a(x)
        x = self.inception3b(x)
        x = self.inception3c(x)
        x = self.inception4a(x)
        x = self.inception4b(x)
        x = self.maxpool3(x)
        
        x = self.inception4c(x)
        x = self.inception4d(x)
        
        # Final layers
        x = self.avgpool(x)
        x = mx.reshape(x, (-1, 1024))
        if training:
            x = self.dropout(x)
        x = self.fc(x)
        x = self.bn2(x)
        
        # L2 normalization to create the embedding
        embedding = self.l2_norm(x)
        return embedding

In [None]:
class TripletLoss(nn.Module):
    """Triplet loss with hard positive/negative mining"""
    
    def __init__(self, margin: float = 0.2):
        super().__init__()
        self.margin = margin
    
    def __call__(self, embeddings, labels):
        # Get the pairwise distance matrix
        pairwise_dist = self._pairwise_distances(embeddings)
        
        # For each anchor, get the hardest positive
        # First, get a mask for valid positive pairs (same class)
        labels = mx.array(labels)
        mask_positives = mx.equal(mx.expand_dims(labels, axis=0), mx.expand_dims(labels, axis=1))
        
        # Exclude the diagonal from mask_positives (distance to self is 0, not a useful positive)
        mask_positives = mask_positives & (1 - mx.eye(labels.shape[0], dtype=mx.bool_))
        
        # Get hardest positives (maximum distance to positive samples)
        hardest_positive_dist = mx.max(pairwise_dist * mask_positives.astype(mx.float32), axis=1)
        
        # For each anchor, get the hardest negative
        # First, get a mask for valid negative pairs (different class)
        mask_negatives = ~mask_positives
        
        # Make invalid negatives have large distance so they won't be selected
        max_dist = mx.max(pairwise_dist)
        neg_dist = pairwise_dist * mask_negatives.astype(mx.float32) + max_dist * (~mask_negatives).astype(mx.float32)
        
        # Get hardest negatives (minimum distance to negative samples)
        hardest_negative_dist = mx.min(neg_dist, axis=1)
        
        # Calculate triplet loss
        triplet_loss = mx.maximum(hardest_positive_dist - hardest_negative_dist + self.margin, 0.0)
        
        # Get final mean triplet loss
        return mx.mean(triplet_loss)
    
    def _pairwise_distances(self, embeddings):
        """Compute the 2D matrix of distances between all embeddings."""
        # Get dot product (batch_size, batch_size)
        dot_product = mx.matmul(embeddings, mx.transpose(embeddings))
        
        # Get squared L2 norm for each embedding (batch_size, 1)
        square_norm = mx.sum(mx.square(embeddings), axis=1)
        
        # Calculate pairwise distance matrix 
        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.dot(b)
        distances = mx.expand_dims(square_norm, 1) + mx.expand_dims(square_norm, 0) - 2.0 * dot_product
        
        # Because of computation errors, some distances might be negative
        distances = mx.maximum(distances, 0.0)
        
        # If we want the actual distance, we can use sqrt
        return mx.sqrt(distances)

In [None]:
def preprocess_image(image_path: str, target_size: Tuple[int, int] = (160, 160)) -> mx.array:
    """
    Load and preprocess an image for FaceNet.
    
    Args:
        image_path: Path to the image file
        target_size: Target size for the image (height, width)
        
    Returns:
        Preprocessed image as MLX array
    """
    try:
        import cv2
        # Load image
        img = cv2.imread(image_path)
        if img is None:
            raise ValueError(f"Unable to load image from {image_path}")
        
        # Convert BGR to RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        
        # Resize to target size
        img = cv2.resize(img, target_size[::-1])  # CV2 takes (width, height)
        
        # Normalize pixel values to [-1, 1]
        img = img.astype(np.float32) / 127.5 - 1.0
        
        # Convert to MLX array and add batch dimension
        return mx.array(img).transpose(2, 0, 1)  # Convert to channels-first format
    
    except Exception as e:
        print(f"Error preprocessing image {image_path}: {e}")
        return None

# Visualize preprocessed images
def visualize_preprocessed_image(image_path: str):
    """Visualize both the original and preprocessed image"""
    import cv2
    import matplotlib.pyplot as plt
    
    # Load original image
    original_img = cv2.imread(image_path)
    original_img = cv2.cvtColor(original_img, cv2.COLOR_BGR2RGB)
    
    # Get preprocessed image and convert back for visualization
    preprocessed = preprocess_image(image_path)
    # Convert from channels-first back to channels-last for display
    preprocessed_viz = preprocessed.transpose(1, 2, 0).numpy()
    # Rescale from [-1, 1] to [0, 1] for visualization
    preprocessed_viz = (preprocessed_viz + 1.0) / 2.0
    
    # Plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
    ax1.imshow(original_img)
    ax1.set_title('Original Image')
    ax1.axis('off')
    
    ax2.imshow(preprocessed_viz)
    ax2.set_title('Preprocessed Image')
    ax2.axis('off')
    
    plt.tight_layout()
    plt.show()

In [17]:
def load_pretrained_weights(facenet_mlx: FaceNet, torch_model_path: str) -> FaceNet:
    """
    Load pretrained weights from a PyTorch FaceNet model.
    
    Args:
        facenet_mlx: MLX FaceNet model
        torch_model_path: Path to PyTorch model file
        
    Returns:
        MLX FaceNet model with loaded weights
    """
    try:
        import torch
        
        # Load PyTorch model
        torch_model = torch.load(torch_model_path, map_location="cpu")
        
        # Create a state dict for MLX model
        state_dict = {}
        
        # Map PyTorch parameter names to MLX parameter names
        # This is a simplified example - you would need to adapt based on your specific models
        for name, param in torch_model.items():
            # Convert PyTorch tensor to numpy array
            param_np = param.numpy()
            
            # Handle convolution weight format differences (PyTorch uses OIHW, MLX uses OIHW)
            if len(param_np.shape) == 4:
                # No need to transpose for this case, format is the same
                pass
            
            # Add to MLX state dict - you might need to adjust the naming convention here
            # based on your specific PyTorch model
            state_dict[name] = mx.array(param_np)
        
        # Load state dict into MLX model
        facenet_mlx.load_weights(state_dict)
        
        print("Loaded pretrained weights from PyTorch model")
        return facenet_mlx
    
    except Exception as e:
        print(f"Error loading pretrained weights: {e}")
        return facenet_mlx

In [18]:
def recognize_faces(
    model: FaceNet,
    unknown_face_path: str,
    reference_faces: List[Tuple[str, str]],
    threshold: float = 0.7
) -> List[Tuple[str, float]]:
    """
    Recognize faces by comparing an unknown face against reference faces.
    
    Args:
        model: Trained FaceNet model
        unknown_face_path: Path to unknown face image
        reference_faces: List of (person_name, face_image_path) tuples
        threshold: Similarity threshold (cosine similarity)
        
    Returns:
        List of (person_name, similarity) tuples for matches above threshold
    """
    # Preprocess unknown face
    unknown_face = preprocess_image(unknown_face_path)
    if unknown_face is None:
        return []
    
    # Get embedding for unknown face
    unknown_embedding = model(mx.expand_dims(unknown_face, 0))[0]
    
    matches = []
    
    # Compare with reference faces
    for name, face_path in reference_faces:
        # Preprocess reference face
        ref_face = preprocess_image(face_path)
        if ref_face is None:
            continue
        
        # Get embedding for reference face
        ref_embedding = model(mx.expand_dims(ref_face, 0))[0]
        
        # Calculate cosine similarity
        similarity = mx.sum(unknown_embedding * ref_embedding) / (
            mx.sqrt(mx.sum(unknown_embedding ** 2)) * mx.sqrt(mx.sum(ref_embedding ** 2))
        )
        
        # Check if similarity is above threshold
        similarity_value = similarity.item()
        if similarity_value > threshold:
            matches.append((name, similarity_value))
    
    # Sort matches by similarity (highest first)
    return sorted(matches, key=lambda x: x[1], reverse=True)

# Visualize face recognition results
def visualize_face_matches(
    unknown_face_path: str,
    matches: List[Tuple[str, float]],
    reference_faces: List[Tuple[str, str]]
):
    """Visualize the unknown face and its matches"""
    import cv2
    import matplotlib.pyplot as plt
    
    # Get paths of matching faces
    matching_faces = []
    for name, similarity in matches:
        for ref_name, ref_path in reference_faces:
            if name == ref_name:
                matching_faces.append((ref_path, name, similarity))
                break
    
    # Number of faces to display (unknown + matches)
    n_faces = 1 + len(matching_faces)
    
    fig = plt.figure(figsize=(4 * n_faces, 5))
    
    # Show unknown face
    unknown_img = cv2.imread(unknown_face_path)
    unknown_img = cv2.cvtColor(unknown_img, cv2.COLOR_BGR2RGB)
    ax = fig.add_subplot(1, n_faces, 1)
    ax.imshow(unknown_img)
    ax.set_title("Unknown Face")
    ax.axis('off')
    
    # Show matching faces
    for i, (face_path, name, similarity) in enumerate(matching_faces):
        img = cv2.imread(face_path)
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        ax = fig.add_subplot(1, n_faces, i + 2)
        ax.imshow(img)
        ax.set_title(f"{name}\nSimilarity: {similarity:.4f}")
        ax.axis('off')
    
    plt.tight_layout()
    plt.show()

In [19]:
def train_facenet(
    model: FaceNet,
    train_data_path: str,
    val_data_path: Optional[str] = None,
    epochs: int = 10,
    batch_size: int = 32,
    learning_rate: float = 0.001,
    margin: float = 0.2
):
    """
    Train FaceNet model using triplet loss.
    
    Args:
        model: FaceNet model
        train_data_path: Path to training data directory
        val_data_path: Path to validation data directory (optional)
        epochs: Number of training epochs
        batch_size: Batch size
        learning_rate: Learning rate
        margin: Margin for triplet loss
        
    Returns:
        Trained model
    """
    # Set up optimizer
    optimizer = optim.Adam(learning_rate=learning_rate)
    
    # Set up loss function
    triplet_loss = TripletLoss(margin=margin)
    
    # Define training step
    @mx.compile
    def train_step(model, images, labels):
        def loss_fn(model):
            embeddings = model(images, training=True)
            loss = triplet_loss(embeddings, labels)
            return loss
        
        loss, grads = nn.value_and_grad(loss_fn)(model)
        optimizer.update(model, grads)
        return loss
    
    # TODO: Implement data loading and training loop
    # This would require implementing a DataLoader for MLX
    # For brevity, only the core training step is shown
    
    print("Training function defined - actual implementation would require a DataLoader")
    return model

# Simple data loader for triplet generation
def create_triplet_batch(face_paths, person_ids, batch_size=32):
    """
    Create a batch of triplets for training.
    
    Args:
        face_paths: List of face image paths
        person_ids: List of person IDs corresponding to face_paths
        batch_size: Number of triplets in the batch
        
    Returns:
        Tuple of (anchor_images, positive_images, negative_images)
    """
    import random
    
    # Group faces by person
    person_to_faces = {}
    for face_path, person_id in zip(face_paths, person_ids):
        if person_id not in person_to_faces:
            person_to_faces[person_id] = []
        person_to_faces[person_id].append(face_path)
    
    # Ensure each person has at least 2 face images
    valid_persons = [pid for pid, faces in person_to_faces.items() if len(faces) >= 2]
    
    if len(valid_persons) < 2:
        raise ValueError("Need at least 2 persons with 2+ images each for triplet generation")
    
    anchors = []
    positives = []
    negatives = []
    
    for _ in range(batch_size):
        # Select anchor person
        anchor_person = random.choice(valid_persons)
        
        # Select anchor and positive (different images of same person)
        anchor_face, positive_face = random.sample(person_to_faces[anchor_person], 2)
        
        # Select negative person (different from anchor)
        negative_persons = [p for p in valid_persons if p != anchor_person]
        negative_person = random.choice(negative_persons)
        
        # Select negative face
        negative_face = random.choice(person_to_faces[negative_person])
        
        # Add to batch
        anchors.append(preprocess_image(anchor_face))
        positives.append(preprocess_image(positive_face))
        negatives.append(preprocess_image(negative_face))
    
    # Stack images into batches
    anchor_batch = mx.stack(anchors)
    positive_batch = mx.stack(positives)
    negative_batch = mx.stack(negatives)
    
    return anchor_batch, positive_batch, negative_batch

In [20]:
# Initialize FaceNet model
facenet = FaceNet(embedding_size=128)
print("FaceNet model successfully initialized")

# Model summary
def print_model_summary(model):
    """Print a summary of the model parameters"""
    total_params = 0
    for name, param in model.parameters().items():
        param_count = np.prod(param.shape)
        total_params += param_count
        print(f"{name}: {param.shape}, {param_count:,} parameters")
    
    print(f"\nTotal parameters: {total_params:,}")

print_model_summary(facenet)

FaceNet model successfully initialized


AttributeError: 'dict' object has no attribute 'shape'

In [13]:
# Define paths
train_data_path = "path/to/training/data"
val_data_path = "path/to/validation/data"

# Train the model
trained_model = train_facenet(
    model=facenet,
    train_data_path=train_data_path,
    val_data_path=val_data_path,
    epochs=10,
    batch_size=32,
    learning_rate=0.001,
    margin=0.2
)

# Save the trained model
mx.save("facenet_mlx_model.npz", trained_model.parameters())

NameError: name 'facenet' is not defined