# Google Colab Turtle Recognition Pipeline - Tag a Turtle
## 1. Introduction
The goal of this pipeline is to identify sea turtles by their facial features using machine learning models. It processes turtle face images, creates embeddings for these faces, and matches them to known turtle identities.

## 2. Requirements
* Environment: Google Colab

* Inputs:
A gallery of turtle images (either flat or by identity).
A query folder with turtle images for identification.

* Outputs:
A gallery per query image showing ranked turtle images by similarity.

* Modularity: 
The pipeline is designed to be modular, allowing easy swapping of models and processing steps.

## Configuration
### Process configuration
Parameters that impact recognition results

In [1]:

# Size (in pixels) that detected face crops are resized to; the embedding
# transform then takes its center crop from an image of this size.
resize_scale = 256

# Square size (in pixels) every image is resized/padded to before it is
# passed to the detection model.
full_resize_n = 1920

# Combine all images of a particular turtle from the database into one
# embedding (average per identity) for improved performance
aggregate_dataset_identities = False

# Combine all query images of the same turtle into one embedding
# (average per identity) for improved performance
aggregate_query_identities = False

# Number of most similar results shown per query (e.g. 10 for a top-10 list)
top_n = 10

### Parallelism configuration
Options that impact system speed and resource usage

In [2]:
# Number of images per DataLoader batch.
# Larger batch sizes speed up processing but may overflow memory.
inference_batch_size = 8

# Request the GPU; the device-selection cell falls back to the CPU when
# CUDA is not available.
use_gpu = True

### Paths configuration
Only for developers

In [3]:
# Root folder of the gallery images (walked recursively when metadata is generated)
dataset_path = "/home/delta/Documents/Turtles/dataset_May15th/train/reid"
# Checkpoint loaded into the YOLO detector
yolo_checkpoint_path = '/home/delta/Documents/Turtles/dataset_ops/Tag-A-Turtle/weights/best.pt'
# State dict loaded into the Resnet101 embedding model
resnet101_checkpoint_path = "/home/delta/Documents/Turtles/Proxy-Anchor/logs/best_wobbly-pond-226.pth"
# Folder with the images to identify (walked recursively)
query_folder_path = "/home/delta/Documents/Turtles/Proxy-Anchor/query_images"

# Path of the embedded dataset. This embedding file has to be regenerated if the model is updated.
saved_embeddings_path = "./RecognitionCache/embeddings/embeddings.pth"
# Path of the generated metadata (named after the last component of dataset_path)
auto_metadata_path = f"/home/delta/Documents/Turtles/Proxy-Anchor/RecognitionCache/datasets/auto_{dataset_path.split('/')[-1]}.csv"
# Optional user-provided metadata CSV; while this is None the auto-generated file above is used
metadata_path = None

In [4]:
# testing
# Global counter; batch_crop_yolov11 increments it for every image on which
# the detector returns more than one box.
error_crops = 0
def set_error_crops(n: int) -> None:
    """Overwrite the global `error_crops` counter with `n`."""
    global error_crops
    error_crops = n

## Imports

In [5]:
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import torch.nn.init as init
from torchvision.models import resnet101
import torch.utils.model_zoo as model_zoo
import os
import pandas as pd
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from PIL import Image
import torch
from torchvision import models, transforms
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import matplotlib.image as mpimg


import pandas as pd
import os
from PIL import Image
import cv2
from ultralytics import YOLO
import torch
import torch.nn.functional as F
from ultralytics import YOLO
from ultralytics.data.loaders import LoadTensor

import torch
from ultralytics import YOLO
from ultralytics.data.loaders import LoadTensor
import re
import csv
# determine device


## Set the device on which to run

In [6]:
# Use CUDA only when it is both requested (`use_gpu`) and available; otherwise
# fall back to the CPU. `device` is always a torch.device, never a plain string.
device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
# Compare on `.type`: a torch.device is not guaranteed to compare equal to the
# string "cpu", so `device == "cpu"` could silently skip this warning.
if use_gpu and device.type == "cpu":
    print("GPU not available.")
# Verify the device being used
print(f"Using device: {device}")

Using device: cuda


## Utils

In [7]:
def extract_labels(file_path):
    """
    Parse the turtle id and the side letter out of an image file name.

    Only the base name of `file_path` is inspected. Two naming schemes are
    recognised (any alphabetic extension):

    * "03_066_R_2003_1.jpg" - id, side, four digits, a running number
    * "03_066 R.JPG"        - id, one whitespace character, side

    The two digit groups of the id may be joined by "_" or "-".

    Returns:
        list | None: [turtle_id, side] when a scheme matches, otherwise None.
    """
    name = os.path.basename(file_path)

    patterns = (
        # Format 1: "03_066_R_2003_1.jpg"
        r"(\d{2}[-_]\d{3})_([A-Za-z])_\d{4}_\d+\.[a-zA-Z]+$",
        # Format 2: "03_066 R.JPG"
        r"(\d{2}[-_]\d{3})\s([A-Za-z])\.[a-zA-Z]+$",
    )
    for pattern in patterns:
        found = re.match(pattern, name)
        if found:
            return [found.group(1), found.group(2)]

    # Neither naming scheme applies
    return None

In [8]:
# Create the local cache folders (relative to the working directory) if they
# are missing; saved_embeddings_path points into ./RecognitionCache/embeddings.
os.makedirs("./RecognitionCache/datasets", exist_ok=True)
os.makedirs("./RecognitionCache/embeddings", exist_ok=True)


## Generating metadata for the dataset

In [9]:
def generate_data_folder_metadata(root_path):
    """
    Build the gallery metadata table from the file names below `root_path`
    and write it to `auto_metadata_path` as CSV.

    Every file found by a recursive walk is passed to `extract_labels`. Files
    whose name matches no known naming format (for example non-image files)
    are skipped and counted instead of aborting the whole run.

    Args:
        root_path (str): Root folder of the gallery images.

    Returns:
        pandas.DataFrame: One row per recognised file with the columns
        'bonaire_turtle_id', 'side' and 'filename' (full path).
    """
    rows = []
    skipped = 0

    # Walk through the root directory and its subdirectories
    for subdir, _, files in os.walk(root_path):
        for file in files:
            file_path = os.path.join(subdir, file)
            labels = extract_labels(file_path)
            if labels is None:
                # extract_labels returns None for names it cannot parse
                skipped += 1
                continue
            rows.append({
                'bonaire_turtle_id': labels[0],
                'side': labels[1],
                'filename': file_path,
            })

    if skipped:
        print(f"Skipped {skipped} file(s) whose names do not match a known naming format.")

    # Explicit columns keep the CSV header intact even when no file matched
    df = pd.DataFrame(rows, columns=['bonaire_turtle_id', 'side', 'filename'])

    # The target folder may not exist yet on a fresh checkout
    out_dir = os.path.dirname(auto_metadata_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    df.to_csv(auto_metadata_path, index=False)
    return df

## Generating metadata for the query files

In [10]:
def generate_query_folder_metadata(root_path):
    """
    Build a metadata table for the query images below `root_path`.

    The identity of a query image is unknown, so the folder that contains it
    is used as its 'bonaire_turtle_id' (images in the same folder share one
    identity) and the side is always "U".

    Args:
        root_path (str): Root folder of the query images (walked recursively).

    Returns:
        pandas.DataFrame: One row per file with the columns
        'bonaire_turtle_id', 'side' and 'filename' (full path). The columns
        are present even when no file was found.
    """
    rows = []

    # Walk through the root directory and its subdirectories
    for subdir, _, files in os.walk(root_path):
        for file in files:
            rows.append({
                'bonaire_turtle_id': f"{subdir}",
                'side': "U",
                'filename': os.path.join(subdir, file),
            })

    # Explicit columns: an empty folder must still yield a frame that can be
    # grouped by 'bonaire_turtle_id' and 'side'.
    return pd.DataFrame(rows, columns=['bonaire_turtle_id', 'side', 'filename'])

## Make sure there is metadata

In [11]:
def ensure_metadata():
    """
    Make sure a metadata CSV exists and return its path.

    The user-provided `metadata_path` is preferred; when it is not set the
    auto-generated file (`auto_metadata_path`) is used. If the chosen file
    does not exist, the auto metadata is (re)generated from `dataset_path`
    and its path is returned instead.

    Returns:
        str: Path of the metadata file that should be loaded.
    """
    # Prefer user provided metadata, otherwise fall back to the auto metadata
    data_path = metadata_path if metadata_path else auto_metadata_path

    # If there is no metadata, generate it
    if not os.path.isfile(data_path):
        print(f"There is no metadata file {data_path}")
        generate_data_folder_metadata(dataset_path)
        data_path = auto_metadata_path

    # Returning the resolved path is what makes the lookup above usable by callers
    return data_path

## Resize image tensor batch from any size to size N

In [12]:
def resize_images_to_n_padded(image_batch, n=224):
    """
    Letterbox every image of a batch into an n x n square.

    Each image is scaled, keeping its aspect ratio, so that its LONGER side
    becomes `n`; the shorter side is then padded symmetrically with black up
    to `n`. Square images are simply resized to n x n.

    The images take a round trip through PIL (ToPILImage -> Resize/Pad ->
    ToTensor), so the returned tensor is produced by `transforms.ToTensor`
    regardless of where the input tensor lived.

    Args:
        image_batch (torch.Tensor): A tensor batch of images with shape (B, C, H, W), where B is the batch size,
                                   C is the number of channels (3 for RGB), H is the height, and W is the width.
        n (int): The target size for the square image (default 224).
    Returns:
        torch.Tensor: A tensor batch of images, each of size n x n.
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()

    batch_images = []
    for img_tensor in image_batch:
        img = to_pil(img_tensor)
        width, height = img.size

        if width == height:
            new_width = new_height = n
        elif width < height:
            # Portrait: height becomes n, width shrinks proportionally.
            new_height = n
            # max(1, ...) keeps extremely thin images from collapsing to 0 px,
            # which Resize cannot handle.
            new_width = max(1, int(width * (new_height / height)))
        else:
            # Landscape: width becomes n, height shrinks proportionally.
            new_width = n
            new_height = max(1, int(height * (new_width / width)))

        resized = transforms.Resize((new_height, new_width))(img)

        if new_width != n or new_height != n:
            # Split the missing pixels evenly; the odd pixel goes right/bottom.
            pad_left = (n - new_width) // 2
            pad_top = (n - new_height) // 2
            pad_right = n - new_width - pad_left
            pad_bottom = n - new_height - pad_top
            resized = transforms.Pad(
                (pad_left, pad_top, pad_right, pad_bottom),
                fill=0,
                padding_mode='constant',
            )(resized)

        batch_images.append(to_tensor(resized))

    # Stack images into a single tensor batch
    return torch.stack(batch_images)


## Dataset Class

In [13]:

class TurtlesDataset(Dataset):
    """
    Image dataset built from a metadata table with the columns
    'bonaire_turtle_id', 'side' and 'filename'.

    The table is either passed in directly (`data`, a pandas DataFrame) or read
    from the metadata CSV (`metadata_path` if set, otherwise
    `auto_metadata_path`). Rows whose file does not exist are dropped.

    Attributes set by the constructor:
        im_paths:      image path per sample
        positions:     side label per sample
        classes:       sorted list of identity strings
        class_to_idx:  identity string -> class index
        ys:            class index per sample
    """

    def __init__(self, mode: str, root: str = None, transform=None, ignoreThreshold=0, data=None):
        """
        Args:
            mode: Stored lower-cased in `self.mode`; not used for filtering here.
            root: Folder the metadata path is joined to when `data` is not given.
            transform: Optional callable applied to the PIL image in __getitem__.
            ignoreThreshold: When > 0, (id, side) groups with at most this many
                rows are dropped.
            data: Optional pre-built metadata DataFrame; skips reading the CSV.

        Raises:
            FileNotFoundError: No `data` was given and the metadata CSV is missing.
            RuntimeError: No usable image remains.
        """
        self.mode = mode.lower()
        self.transform = transform
        self.root = root

        # `data is None` rather than `not data`: the truth value of a
        # DataFrame is ambiguous and raises ValueError.
        if data is None:
            data_path = metadata_path if metadata_path else auto_metadata_path
            meta_path = os.path.join(root, data_path) if root else data_path
            if not os.path.isfile(meta_path):
                raise FileNotFoundError(f"No metadata file found (expected {data_path}).")
            data_df = pd.read_csv(meta_path)
        else:
            data_df = data

        print(f"Dataset size: {len(data_df)}")
        grouped = data_df.groupby(['bonaire_turtle_id', 'side'])
        group_counts = grouped.size()

        if ignoreThreshold > 0:
            multi_sample_groups = group_counts[group_counts > ignoreThreshold]
            print("Removing single-sample classes for BonaireTurtlesDataset.")
        else:
            multi_sample_groups = group_counts

        self.im_paths, self._y_strs, self.positions = [], [], []
        for (turtle_id, side), group_df in grouped:
            if (turtle_id, side) not in multi_sample_groups.index:
                continue

            for _, row in group_df.iterrows():
                filename = row.get("filename", "")
                # Empty CSV cells are read as NaN (a float), which has no .strip()
                if not isinstance(filename, str):
                    continue
                img_path = filename.strip()
                if not img_path or not os.path.isfile(img_path):
                    continue

                self.im_paths.append(img_path)
                self._y_strs.append(f"{turtle_id}")
                self.positions.append(side)

        if not self.im_paths:
            raise RuntimeError("Dataset is empty.")

        self.classes = sorted(set(self._y_strs))
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.ys = [self.class_to_idx[s] for s in self._y_strs]
        print("Data length:", len(self.ys))

    def __getitem__(self, index):
        """
        Return (image, class_index) for sample `index`.

        With a transform the image is transformed and then letterboxed to
        `full_resize_n`; without one the RGB PIL image is returned unchanged.
        """
        img = Image.open(self.im_paths[index]).convert("RGB")
        if self.transform:
            img = self.transform(img).unsqueeze(0)
            img = resize_images_to_n_padded(img, full_resize_n).squeeze(0)
        target = self.ys[index]
        return img, target

    def get_path(self, index):
        """Return the image path of sample `index`."""
        return self.im_paths[index]

    def __len__(self):
        return len(self.ys)


## Get Labels from Dataset/DataLoader

In [14]:
def get_dataloader_labels(dataloader):
    """
    Collect the labels of every batch of a DataLoader into one tensor.

    The loader is iterated from start to end; only the second element of each
    (inputs, labels) batch is kept, but the inputs are still produced by the
    loader on the way.

    Args:
        dataloader (torch.utils.data.DataLoader): Iterable yielding (inputs, labels) batches.

    Returns:
        torch.Tensor: All label batches concatenated along dimension 0.
    """
    label_batches = [batch_labels for _, batch_labels in dataloader]
    return torch.cat(label_batches, dim=0)


## Resnet Implementation

In [15]:


class Resnet101(nn.Module):
    """
    ResNet-101 backbone with a linear embedding head.

    The feature map of the last residual stage is reduced by global average
    pooling plus global max pooling, flattened, and projected to
    `embedding_size` dimensions. With `is_norm` the result is L2-normalised.
    """

    def __init__(self,embedding_size, pretrained=True, is_norm=True, bn_freeze = True):
        super().__init__()

        self.model = resnet101(pretrained)
        self.is_norm = is_norm
        self.embedding_size = embedding_size
        self.num_ftrs = self.model.fc.in_features

        # Extra heads are attached to the backbone object itself.
        self.model.gap = nn.AdaptiveAvgPool2d(1)
        self.model.gmp = nn.AdaptiveMaxPool2d(1)
        self.model.embedding = nn.Linear(self.num_ftrs, self.embedding_size)
        self._initialize_weights()

        if not bn_freeze:
            return
        # Put every BatchNorm2d layer in eval mode and stop training its
        # affine parameters.
        for module in self.model.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
                module.weight.requires_grad_(False)
                module.bias.requires_grad_(False)

    def l2_norm(self,input):
        """Scale each row of a 2-D `input` to unit L2 length (1e-12 is added to the squared norm)."""
        squared_norm = torch.sum(torch.pow(input, 2), 1) + 1e-12
        row_norm = torch.sqrt(squared_norm).view(-1, 1)
        scaled = torch.div(input, row_norm.expand_as(input))
        return scaled.view(input.size())

    def forward(self, x):
        """
        Compute one embedding per image.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W).
            The image batch should be normalized to
            resnet_mean = [0.485, 0.456, 0.406]
            resnet_std = [0.229, 0.224, 0.225]

        Returns:
            torch.Tensor: Embeddings with `embedding_size` values per image,
            L2-normalised when `is_norm` is set.
        """
        backbone = self.model
        stages = (
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            backbone.layer4,
        )
        for stage in stages:
            x = stage(x)

        # Sum of global max pooling and global average pooling
        pooled = backbone.gmp(x) + backbone.gap(x)
        flat = pooled.view(pooled.size(0), -1)
        embedded = backbone.embedding(flat)

        return self.l2_norm(embedded) if self.is_norm else embedded

    def _initialize_weights(self):
        """Kaiming-normal init for the embedding weight, zero for its bias."""
        head = self.model.embedding
        init.kaiming_normal_(head.weight, mode='fan_out')
        init.constant_(head.bias, 0)


## Resnet101 Model

In [16]:
# The checkpoint below restores every weight (load_state_dict is strict by
# default and raises on missing keys), so the ImageNet weights are not needed:
# pretrained=False avoids a download whose result would be overwritten anyway.
resnet101_model = Resnet101(embedding_size=512, pretrained=False, is_norm=True, bn_freeze = True)

# Path to checkpoint
# map_location="cpu" lets a checkpoint that was saved from a GPU be loaded on a
# machine without CUDA; the tensors are copied into the model by load_state_dict.
checkpoint = torch.load(resnet101_checkpoint_path, map_location="cpu")
resnet101_model.load_state_dict(checkpoint)
# Inference only: switch to eval mode
resnet101_model.eval()



Resnet101(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          

## YOLO Detection 

In [17]:
yolo_model = YOLO(yolo_checkpoint_path).to(device)  # Load the detector from yolo_checkpoint_path and move it to the selected device

## Batch Process image tensors with YOLO
The function below takes an image tensor batch and crops each image around the turtle head.
If no turtle head is found, the original image is used.

The function returns the tensor batch, cropped around the turtles' heads. 

In [18]:
import torch
from PIL import Image, ImageDraw
from torchvision import transforms
import matplotlib.pyplot as plt

def batch_crop_yolov11(image_batch: torch.Tensor, model: YOLO, conf_threshold: float = 0.5, device: str = 'cuda'):
    """
    Crop and return a batch of cropped images from the input tensor batch using YOLOv11.
    If no object is detected, return the original image for the respective batch item.

    Every output image is letterboxed to `resize_scale` so the results can be
    stacked. The global `error_crops` counter is incremented for each image
    on which more than one box is detected; in that case the box with the
    highest confidence is used.

    Args:
        image_batch (torch.Tensor): Input tensor of shape (B, C, H, W).
        model (YOLO): Preloaded YOLOv11 model.
        conf_threshold (float): Confidence threshold for detections.
        device (str): Device to run the model on ('cuda' or 'cpu').

    Returns:
        torch.Tensor: Tensor of cropped images or original images if no detections.
    """
    global error_crops

    # Ensure the model is on the correct device
    model.to(device)

    # Load and preprocess the image batch
    images = image_batch.to(device)

    # Perform inference
    results = model.predict(source=images, conf=conf_threshold, device=device, verbose=False)

    cropped_images = []
    for result, image in zip(results, images):
        # Default: keep the whole image (no detection, or unusable box)
        crop = image

        boxes = result.boxes
        if boxes is not None and len(boxes.xyxy) > 0:
            if len(boxes.xyxy) != 1:
                error_crops = error_crops + 1

            # Pick the most confident box explicitly instead of relying on
            # the order in which the detector returns its boxes.
            best = int(boxes.conf.argmax())
            x_min, y_min, x_max, y_max = boxes.xyxy[best].int().cpu().tolist()

            # Clamp to the image and reject zero-area boxes: an empty slice
            # cannot be converted to an image or resized.
            height, width = image.shape[-2:]
            x_min, x_max = max(0, x_min), min(width, x_max)
            y_min, y_max = max(0, y_min), min(height, y_max)
            if x_max > x_min and y_max > y_min:
                crop = image[:, y_min:y_max, x_min:x_max]

        img = resize_images_to_n_padded(crop.unsqueeze(0), resize_scale)
        cropped_images.append(img.squeeze(0))

    # Stack the cropped images into a single tensor
    return torch.stack(cropped_images)


## Standard Transform before Embedding model

In [19]:
def make_transform(original = 256, output=224):
    """
    Build the preprocessing applied to a single image tensor before it is
    embedded: convert to PIL, resize to `original`, center-crop to `output`,
    convert back to a tensor and normalise per channel.

    Args:
        original: Size passed to `transforms.Resize`.
        output: Size passed to `transforms.CenterCrop`.

    Returns:
        transforms.Compose: The composed transform.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]
    steps = [
        transforms.ToPILImage(),
        transforms.Resize(original),
        transforms.CenterCrop(output),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    return transforms.Compose(steps)


# Transform used by the embedding step; resizes to the configured resize_scale
resnet_transform = make_transform(original=resize_scale)

## Resize, crop and normalize tensors for embedding model

In [20]:
def pre_embed_transform(image_tensor_batch):
    """Apply `resnet_transform` to each image of the batch and stack the results into one tensor."""
    return torch.stack([resnet_transform(single_image) for single_image in image_tensor_batch])
    

## Inference Pipeline
Handles tensorized images in batches, applying all the inferences and transforms required to produce the embeddings.

In [21]:
def infer_pipeline(image_tensor):
    """
    Turn a batch of full images into a batch of face embeddings.

    Steps: detect and crop the face with the YOLO model, letterbox the crops
    to `resize_scale`, apply the embedding transform (center crop and
    normalisation), and run the Resnet101 embedding model.

    Both models run on the globally selected `device`, so the `use_gpu`
    setting is honoured for detection and embedding alike.

    Args:
        image_tensor (torch.Tensor): Image batch of shape (B, C, H, W).

    Returns:
        torch.Tensor: Embeddings of the batch, on the CPU.
    """
    # Inference only: without no_grad every returned embedding would keep its
    # autograd graph alive, and memory would grow with each processed batch.
    with torch.no_grad():
        # Detect face in the image (pass the selected device explicitly; the
        # crop function would otherwise default to 'cuda')
        faces_tensor = batch_crop_yolov11(image_tensor, yolo_model, device=device)
        print("Number of cropping errors", error_crops)
        # resize the images to work for resnet
        faces_tensor_resized = resize_images_to_n_padded(faces_tensor, n=resize_scale)
        # crop and normalize before embedding model
        normalized_face_tensor = pre_embed_transform(faces_tensor_resized)

        # Model and input must live on the same device; .to() is a no-op
        # once the model has been moved.
        resnet101_model.to(device)
        # Generate embedding for the detected face
        embeddings = resnet101_model(normalized_face_tensor.to(device))

    return embeddings.cpu()

## Function to run the inference pipeline over the full dataset

In [22]:
def process_images_to_embeddings(data_loader):
    """
    Run the inference pipeline over every batch of `data_loader`.

    Args:
        data_loader: Iterable yielding (inputs, targets) batches; the targets
            are ignored here.

    Returns:
        torch.Tensor: One embedding row per image, in loader order, on the
        CPU. An empty tensor is returned when the loader yields no batch.
    """
    processed_batches = []
    print("Processing data...")
    for inputs, _ in tqdm(data_loader, desc="Processing data"):
        processed_batches.append(infer_pipeline(inputs).cpu())

    if not processed_batches:
        return torch.empty(0)

    # Concatenate along the batch dimension. torch.stack would add a leading
    # "batch number" dimension and fails outright when the last batch is
    # smaller than the others.
    return torch.cat(processed_batches, dim=0)

In [23]:

def turtle_similarities(query_embedding, gallery_embeddings):
    """Return the cosine similarity between the query embedding and the gallery embeddings (computed along dimension 1, the default of F.cosine_similarity)."""
    return F.cosine_similarity(query_embedding, gallery_embeddings)


In [24]:

def plot_recognition_results(query_image, similar_images, titles):
    """
    Show the query image next to its matches in a single row.

    Args:
        query_image: Image data accepted by `plt.imshow`, shown first with the title "Query".
        similar_images: Image files, each read with `mpimg.imread`.
        titles: Title for each entry of `similar_images` (same order).
    """
    panel_count = len(similar_images) + 1
    plt.figure(figsize=(10, 5))

    # First panel: the query
    plt.subplot(1, panel_count, 1)
    plt.imshow(query_image)
    plt.title("Query")

    # Remaining panels: the matches, left to right
    for idx, sim_img in enumerate(similar_images):
        panel = idx + 2
        plt.subplot(1, panel_count, panel)
        plt.imshow(mpimg.imread(sim_img))
        plt.title(titles[idx])

    plt.show()


## Embeddings aggregator class

In [25]:
class EmbeddingAggregator(nn.Module):
    """
    Average embeddings in the batch per turtle identity to produce more accurate embeddings.

    Parameters
    ----------
    reduction : str
        'mean' (default) or 'weighted'.
        When 'weighted', `weights` must be supplied in forward().
    return_index_map : bool
        If True, also returns a tensor that maps each original sample
        to its aggregated row – handy for gathering losses.
    """

    def __init__(self, reduction: str = "mean", return_index_map: bool = False):
        super().__init__()
        assert reduction in {"mean", "weighted"}
        self.reduction = reduction
        self.return_index_map = return_index_map

    @torch.no_grad()
    def forward(
        self,
        emb: torch.Tensor,             # shape (B, D)
        labels: torch.Tensor,          # shape (B,)
        weights: torch.Tensor | None = None
    ):
        device = emb.device
        labels = labels.to(device)

        # ---- gather bookkeeping -------------------------------------------------
        # unique ids and position of each sample’s aggregated row
        uniq_ids, inv = torch.unique(labels, return_inverse=True)
        num_ids = uniq_ids.size(0)

        if self.reduction == "mean":
            # quick path – use scatter_add then divide by counts
            summed = torch.zeros(num_ids, emb.size(1), device=device).scatter_add_(0,
                      inv.unsqueeze(-1).expand_as(emb), emb)
            counts = torch.bincount(inv, minlength=num_ids).unsqueeze(1)
            agg = summed / counts.clamp_min(1)

        else:  # weighted mean
            assert weights is not None, "`weights` required for weighted reduction"
            weights = weights.to(device).unsqueeze(1)  # (B, 1)
            w_sum = torch.zeros(num_ids, 1, device=device).scatter_add_(0, inv.unsqueeze(-1), weights)
            w_emb = torch.zeros_like(summed).scatter_add_(0, inv.unsqueeze(-1).expand_as(emb), emb * weights)
            agg = w_emb / w_sum.clamp_min(1e-8)

        if self.return_index_map:
            return agg, uniq_ids, inv
        return agg, uniq_ids, None



## Sort embeddings by similarity with their labels

In [26]:
def sort_em_by_sim(similarities, query_labels):
    """
    Order similarities from highest to lowest and reorder the labels to match.

    Entries with equal similarity keep their original relative order.

    Args:
        similarities: 1-D tensor of similarity scores.
        query_labels: Tensor whose i-th element is the label of the i-th score.

    Returns:
        (torch.Tensor, torch.Tensor): sorted similarities, labels in the same order.
    """
    sim_values = [score.item() for score in similarities]
    label_values = [query_labels[pos].item() for pos in range(len(sim_values))]

    # Positions ranked by descending similarity
    ranking = sorted(range(len(sim_values)), key=lambda pos: sim_values[pos], reverse=True)

    ranked_sims = torch.tensor([sim_values[pos] for pos in ranking])
    ranked_labels = torch.tensor([label_values[pos] for pos in ranking])
    return ranked_sims, ranked_labels

## Retrieve the images in a dataset that match the label

In [27]:
def find_matching_indices(string_list, query):
    """Return, in ascending order, every position of `string_list` whose element equals `query`."""
    hits = []
    for position, candidate in enumerate(string_list):
        if candidate == query:
            hits.append(position)
    return hits

## Basic tensor transform

In [28]:
# Minimal transform: only converts the loaded image to a tensor
to_tensor_transform = transforms.Compose([transforms.ToTensor()])

## Main Pipeline Cell

In [29]:
embeddings_exist = os.path.exists(saved_embeddings_path)
error_crops = 0

# The gallery dataset is needed in every case - also when the embeddings come
# from the cache - because it maps result rows back to image paths and names.
# Auto metadata is (re)generated when embeddings are rebuilt or the file is missing.
if not metadata_path and (not embeddings_exist or not os.path.isfile(auto_metadata_path)):
    generate_data_folder_metadata(dataset_path)
dataset = TurtlesDataset(root=dataset_path, mode="all", transform=to_tensor_transform)

# Load dataset embeddings
if embeddings_exist:
    dataset_embeddings, dataset_labels = torch.load(saved_embeddings_path)
else:
    data_loader = DataLoader(dataset, batch_size=inference_batch_size, pin_memory=True, shuffle=False, num_workers=1)
    dataset_embeddings = process_images_to_embeddings(data_loader)
    # The loader is not shuffled, so the labels are simply the dataset's
    # class indices in order; no second pass over all images is needed.
    dataset_labels = torch.tensor(dataset.ys)
    torch.save((dataset_embeddings, dataset_labels), saved_embeddings_path)

# One row per image, whatever batching the embedding step used
dataset_embeddings = dataset_embeddings.reshape(-1, dataset_embeddings.shape[-1])
if dataset_embeddings.shape[0] != len(dataset):
    raise RuntimeError(
        f"{saved_embeddings_path} holds {dataset_embeddings.shape[0]} embeddings but the dataset has "
        f"{len(dataset)} images; delete the file to regenerate the embeddings."
    )

# Load query embeddings
query_metadata = generate_query_folder_metadata(query_folder_path)
query_dataset = TurtlesDataset(data=query_metadata, mode="all", transform=to_tensor_transform)
query_loader = DataLoader(query_dataset, batch_size=inference_batch_size, pin_memory=True, shuffle=False, num_workers=1)
query_embeddings = process_images_to_embeddings(query_loader)
query_embeddings = query_embeddings.reshape(-1, query_embeddings.shape[-1])
query_labels = torch.tensor(query_dataset.ys)

# Aggregate embeddings by identity if enabled
aggregator = EmbeddingAggregator().to(device)
if aggregate_dataset_identities:
    dataset_embeddings, dataset_labels, _ = aggregator(dataset_embeddings, labels=dataset_labels)
if aggregate_query_identities:
    query_embeddings, query_labels, _ = aggregator(query_embeddings, labels=query_labels)

print("Comparing embeddings...")
# Row numbers of the gallery embeddings; sorting these (instead of the labels)
# tells us which gallery row produced each similarity.
gallery_rows = torch.arange(dataset_embeddings.shape[0])
for query_index in range(query_embeddings.shape[0]):
    query_em = query_embeddings[query_index]

    # (1, D) against (N, D) -> one similarity per gallery row
    similarities = turtle_similarities(query_em.unsqueeze(0), dataset_embeddings)
    sorted_similarities, sorted_rows = sort_em_by_sim(similarities, gallery_rows)

    # Keep the top_n results (slice the results, not the returned tuple)
    top_similarities = sorted_similarities[:top_n].tolist()
    top_rows = sorted_rows[:top_n].tolist()

    images = []
    images_labels = []
    for similarity, row in zip(top_similarities, top_rows):
        label = int(dataset_labels[row])
        if aggregate_dataset_identities:
            # An aggregated row stands for an identity: show its first image
            image_index = find_matching_indices(dataset.ys, label)[0]
        else:
            # Row i is the embedding of dataset image i
            image_index = row
        # plot_recognition_results reads gallery images from their paths
        images.append(dataset.get_path(image_index))
        images_labels.append(f"{dataset.classes[label]} ({similarity:.2f})")

    if aggregate_query_identities:
        query_image_index = find_matching_indices(query_dataset.ys, int(query_labels[query_index]))[0]
    else:
        query_image_index = query_index
    query_image = Image.open(query_dataset.get_path(query_image_index)).convert("RGB")
    plot_recognition_results(query_image, images, images_labels)
     

    

Dataset size: 8396
Data length: 8396
Processing data...


Processing data:   0%|          | 0/1050 [00:00<?, ?it/s]

Number of cropping errors 0


Processing data:   0%|          | 1/1050 [00:02<42:15,  2.42s/it]

Number of cropping errors 2


Processing data:   0%|          | 2/1050 [00:03<25:25,  1.46s/it]

Number of cropping errors 5


Processing data:   0%|          | 4/1050 [00:04<18:10,  1.04s/it]

Number of cropping errors 8


Processing data:   0%|          | 5/1050 [00:05<16:15,  1.07it/s]

Number of cropping errors 12


Processing data:   1%|          | 6/1050 [00:06<14:51,  1.17it/s]

Number of cropping errors 17
Number of cropping errors 22


Processing data:   1%|          | 8/1050 [00:07<14:05,  1.23it/s]

Number of cropping errors 24


Processing data:   1%|          | 9/1050 [00:08<13:37,  1.27it/s]

Number of cropping errors 28


Processing data:   1%|          | 10/1050 [00:09<13:48,  1.26it/s]

Number of cropping errors 32


Processing data:   1%|          | 11/1050 [00:10<15:14,  1.14it/s]

Number of cropping errors 38


Processing data:   1%|          | 12/1050 [00:11<14:56,  1.16it/s]

Number of cropping errors 40
Number of cropping errors 46


Processing data:   1%|▏         | 14/1050 [00:12<14:37,  1.18it/s]

Number of cropping errors 49


Processing data:   1%|▏         | 15/1050 [00:13<14:50,  1.16it/s]

Number of cropping errors 51


Processing data:   2%|▏         | 16/1050 [00:14<14:46,  1.17it/s]

Number of cropping errors 53


Processing data:   2%|▏         | 17/1050 [00:15<15:43,  1.09it/s]

Number of cropping errors 56


Processing data:   2%|▏         | 18/1050 [00:16<17:13,  1.00s/it]

Number of cropping errors 59


Processing data:   2%|▏         | 19/1050 [00:17<16:09,  1.06it/s]

Number of cropping errors 63


Processing data:   2%|▏         | 20/1050 [00:18<17:18,  1.01s/it]

Number of cropping errors 65


Processing data:   2%|▏         | 21/1050 [00:19<16:50,  1.02it/s]

Number of cropping errors 69


Processing data:   2%|▏         | 22/1050 [00:20<17:09,  1.00s/it]

Number of cropping errors 73


Processing data:   2%|▏         | 23/1050 [00:21<17:11,  1.00s/it]

Number of cropping errors 75


Processing data:   2%|▏         | 24/1050 [00:22<17:01,  1.00it/s]

Number of cropping errors 80


Processing data:   2%|▏         | 25/1050 [00:23<16:58,  1.01it/s]

Number of cropping errors 83


Processing data:   2%|▏         | 26/1050 [00:24<17:35,  1.03s/it]

Number of cropping errors 85


Processing data:   3%|▎         | 27/1050 [00:25<16:06,  1.06it/s]

Number of cropping errors 89


Processing data:   3%|▎         | 28/1050 [00:26<15:12,  1.12it/s]

Number of cropping errors 93


Processing data:   3%|▎         | 29/1050 [00:27<14:17,  1.19it/s]

Number of cropping errors 96


Processing data:   3%|▎         | 30/1050 [00:28<14:49,  1.15it/s]

Number of cropping errors 100


Processing data:   3%|▎         | 31/1050 [00:29<19:23,  1.14s/it]

Number of cropping errors 105


Processing data:   3%|▎         | 32/1050 [00:30<18:18,  1.08s/it]

Number of cropping errors 109


Processing data:   3%|▎         | 33/1050 [00:32<19:08,  1.13s/it]

Number of cropping errors 111


Processing data:   3%|▎         | 34/1050 [00:33<19:29,  1.15s/it]

Number of cropping errors 114


Processing data:   3%|▎         | 35/1050 [00:34<19:38,  1.16s/it]

Number of cropping errors 118


Processing data:   3%|▎         | 36/1050 [00:35<20:12,  1.20s/it]

Number of cropping errors 122


Processing data:   4%|▎         | 37/1050 [00:37<21:07,  1.25s/it]

Number of cropping errors 124


Processing data:   4%|▎         | 38/1050 [00:38<21:34,  1.28s/it]

Number of cropping errors 129


Processing data:   4%|▎         | 39/1050 [00:39<21:55,  1.30s/it]

Number of cropping errors 132


Processing data:   4%|▍         | 40/1050 [00:41<22:04,  1.31s/it]

Number of cropping errors 134


Processing data:   4%|▍         | 41/1050 [00:42<21:49,  1.30s/it]

Number of cropping errors 138


Processing data:   4%|▍         | 42/1050 [00:43<19:30,  1.16s/it]

Number of cropping errors 142


Processing data:   4%|▍         | 43/1050 [00:44<20:17,  1.21s/it]

Number of cropping errors 144


Processing data:   4%|▍         | 44/1050 [00:45<20:09,  1.20s/it]

Number of cropping errors 148


Processing data:   4%|▍         | 45/1050 [00:47<20:42,  1.24s/it]

Number of cropping errors 151


Processing data:   4%|▍         | 46/1050 [00:48<20:29,  1.22s/it]

Number of cropping errors 153
Number of cropping errors 154


Processing data:   5%|▍         | 48/1050 [00:50<19:48,  1.19s/it]

Number of cropping errors 155


Processing data:   5%|▍         | 49/1050 [00:51<20:38,  1.24s/it]

Number of cropping errors 157


Processing data:   5%|▍         | 50/1050 [00:53<21:19,  1.28s/it]

Number of cropping errors 161


Processing data:   5%|▍         | 51/1050 [00:54<20:20,  1.22s/it]

Number of cropping errors 164


Processing data:   5%|▍         | 52/1050 [00:55<19:57,  1.20s/it]

Number of cropping errors 164


Processing data:   5%|▌         | 53/1050 [00:57<21:52,  1.32s/it]

Number of cropping errors 167


Processing data:   5%|▌         | 54/1050 [00:58<21:59,  1.32s/it]

Number of cropping errors 168


Processing data:   5%|▌         | 55/1050 [00:59<22:15,  1.34s/it]

Number of cropping errors 171


Processing data:   5%|▌         | 56/1050 [01:01<24:30,  1.48s/it]

Number of cropping errors 173
Number of cropping errors 174


Processing data:   6%|▌         | 58/1050 [01:03<18:54,  1.14s/it]

Number of cropping errors 178


Processing data:   6%|▌         | 59/1050 [01:04<16:25,  1.01it/s]

Number of cropping errors 181


Processing data:   6%|▌         | 60/1050 [01:05<16:31,  1.00s/it]

Number of cropping errors 184


Processing data:   6%|▌         | 61/1050 [01:06<20:13,  1.23s/it]

Number of cropping errors 190


Processing data:   6%|▌         | 62/1050 [01:07<17:55,  1.09s/it]

Number of cropping errors 192


Processing data:   6%|▌         | 63/1050 [01:08<17:35,  1.07s/it]

Number of cropping errors 195


Processing data:   6%|▌         | 64/1050 [01:09<15:42,  1.05it/s]

Number of cropping errors 198


Processing data:   6%|▌         | 65/1050 [01:10<14:27,  1.14it/s]

Number of cropping errors 200
Number of cropping errors 204


Processing data:   6%|▋         | 66/1050 [01:10<13:34,  1.21it/s]

Number of cropping errors 209


Processing data:   6%|▋         | 67/1050 [01:11<13:00,  1.26it/s]

Number of cropping errors 213


Processing data:   7%|▋         | 69/1050 [01:12<12:52,  1.27it/s]

Number of cropping errors 218
Number of cropping errors 221


Processing data:   7%|▋         | 70/1050 [01:14<14:36,  1.12it/s]

Number of cropping errors 227


Processing data:   7%|▋         | 72/1050 [01:16<15:37,  1.04it/s]

Number of cropping errors 231


Processing data:   7%|▋         | 73/1050 [01:17<15:34,  1.05it/s]

Number of cropping errors 234


Processing data:   7%|▋         | 74/1050 [01:18<16:18,  1.00s/it]

Number of cropping errors 237


Processing data:   7%|▋         | 75/1050 [01:19<15:40,  1.04it/s]

Number of cropping errors 238


Processing data:   7%|▋         | 76/1050 [01:20<16:26,  1.01s/it]

Number of cropping errors 242


Processing data:   7%|▋         | 77/1050 [01:21<17:14,  1.06s/it]

Number of cropping errors 246


Processing data:   7%|▋         | 78/1050 [01:22<18:42,  1.15s/it]

Number of cropping errors 248


Processing data:   8%|▊         | 79/1050 [01:24<20:14,  1.25s/it]

Number of cropping errors 251


Processing data:   8%|▊         | 80/1050 [01:25<20:56,  1.30s/it]

Number of cropping errors 253
Number of cropping errors 255


Processing data:   8%|▊         | 82/1050 [01:28<21:43,  1.35s/it]

Number of cropping errors 259


Processing data:   8%|▊         | 83/1050 [01:29<22:44,  1.41s/it]

Number of cropping errors 263


Processing data:   8%|▊         | 84/1050 [01:31<22:51,  1.42s/it]

Number of cropping errors 270


Processing data:   8%|▊         | 85/1050 [01:32<22:14,  1.38s/it]

Number of cropping errors 273


Processing data:   8%|▊         | 85/1050 [01:33<17:45,  1.10s/it]


KeyboardInterrupt: 