# Google Colab Turtle Recognition Pipeline - Tag a Turtle
## 1. Introduction
The goal of this pipeline is to identify sea turtles by their facial features using machine learning models. It processes turtle face images, creates embeddings for these faces, and matches them to known turtle identities.

## 2. Requirements
* Environment: Google Colab

* Inputs:
A gallery of turtle images (either flat or by identity).
A query folder with turtle images for identification.

* Outputs:
A gallery per query image showing ranked turtle images by similarity.

* Modularity: 
The pipeline is designed to be modular, allowing easy swapping of models and processing steps.

## Configuration
### Process configuration
Parameters that impact recognition results

In [3402]:

# Side length (in pixels) of the padded square that each detected face crop is
# resized to. The embedding transform center-crops this square afterwards.
resize_scale: int = 256

# Side length (in pixels) passed to the padded resize in TurtlesDataset.__getitem__,
# i.e. the size of full images before they are handed to the detection model.
full_resize_n: int = 1920

# When True, the embeddings of all gallery images that share a turtle identity
# are averaged into a single embedding per identity before matching.
aggregate_dataset_identities: bool = False

# When True, the embeddings of all query images that share an identity
# are averaged into a single embedding per identity before matching.
aggregate_query_identities: bool = False

# Number of most similar gallery entries to show per query (e.g. 10 for a top-10 ranking).
top_n: int = 10

### Parallelism configuration
Options that impact system speed and resource usage

In [3403]:
# Number of images per DataLoader batch during inference.
# Larger batch sizes speed up processing but may overflow memory.
inference_batch_size: int = 16

# Request GPU execution; the device-selection cell decides the actual device.
use_gpu: bool = True

### Paths configuration
Only for developers

In [3404]:
# Root folder of the gallery images; walked recursively when metadata is generated.
dataset_path = "/home/delta/Documents/Turtles/dataset_May15th/train/reid"
# Weights file handed to the YOLO detector.
yolo_checkpoint_path = '/home/delta/Documents/Turtles/dataset_ops/Tag-A-Turtle/weights/best.pt'
# State dict loaded into the Resnet101 embedding model.
resnet101_checkpoint_path = "/home/delta/Documents/Turtles/Proxy-Anchor/logs/best_wobbly-pond-226.pth"
# Root folder of the query images; walked recursively to build the query metadata.
query_folder_path = "/home/delta/Documents/Turtles/Proxy-Anchor/query_images"

# Path of the embedded dataset. This embedding file has to be regenerated if the model is updated.
saved_embeddings_path = "./RecognitionCache/embeddings/embeddings.pth"
# Path of the generated metadata CSV; its file name is derived from the last
# component of dataset_path (so dataset_path must not end with a slash).
auto_metadata_path = f"/home/delta/Documents/Turtles/Proxy-Anchor/RecognitionCache/datasets/auto_{dataset_path.split('/')[-1]}.csv"
# Optional user-provided metadata CSV; when left as None the auto-generated one is used.
metadata_path = None

In [3405]:
# Diagnostic counter (testing aid): incremented by the YOLO crop functions
# whenever a detection does not yield exactly one box.
error_crops: int = 0
def set_error_crops(n: int) -> None:
    """Overwrite the module-level ``error_crops`` counter with ``n``."""
    global error_crops
    error_crops = n

## Imports

In [3406]:
import torch
from tqdm import tqdm
import torch.nn.init as init
from torchvision.models import resnet101
import os
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torch import nn
import torch.nn.functional as F
import matplotlib.image as mpimg
from PIL import Image
from ultralytics import YOLO

import re
# determine device

## Set the device on which to run

In [3407]:
# Use CUDA only when it is both requested (use_gpu) and actually available;
# otherwise fall back to the CPU. `device` is always a torch.device object.
device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
# Compare on `device.type`: a torch.device does not compare equal to a plain
# string, so `device == "cpu"` would never report the missing GPU.
if use_gpu and device.type == "cpu":
    print("GPU not available.")
# Verify the device being used
print(f"Using device: {device}")

Using device: cuda


## Utils

In [3408]:
def extract_labels(file_path):
    """
    Parse the turtle id and the photographed side out of an image file name.

    Two naming schemes are recognised (any alphabetic extension):
      * "03_066_R_2003_1.jpg"  -> id, side, year, running number
      * "03_066 R.JPG"         -> id, whitespace, side
    The id may use either "_" or "-" between its two numeric parts.

    Args:
        file_path (str): Path or bare file name of the image.

    Returns:
        list | None: [turtle_id, side] when the name matches one of the
        schemes, otherwise None.
    """
    name = os.path.basename(file_path)

    known_formats = (
        # Format 1: "03_066_R_2003_1.jpg"
        r"(\d{2}[-_]\d{3})_([A-Za-z])_\d{4}_\d+\.[a-zA-Z]+$",
        # Format 2: "03_066 R.JPG"
        r"(\d{2}[-_]\d{3})\s([A-Za-z])\.[a-zA-Z]+$",
    )
    for pattern in known_formats:
        found = re.match(pattern, name)
        if found is not None:
            return list(found.groups())

    # Neither naming scheme matched
    return None

In [3409]:
# Make sure the cache folders for generated metadata and saved embeddings exist.
for _cache_dir in ("./RecognitionCache/datasets", "./RecognitionCache/embeddings"):
    os.makedirs(_cache_dir, exist_ok=True)


## Generating metadata for the dataset

In [3410]:
def generate_data_folder_metadata(root_path):
    """
    Build the gallery metadata table from the file names under ``root_path``
    and write it to ``auto_metadata_path`` as CSV.

    Every file found by walking ``root_path`` recursively is passed to
    ``extract_labels``. Files whose names do not match a known naming scheme
    (for which ``extract_labels`` returns None) are skipped instead of
    crashing the whole run.

    Args:
        root_path (str): Root folder of the gallery images.

    Returns:
        pandas.DataFrame: One row per recognised image with the columns
        'bonaire_turtle_id', 'side' and 'filename' (full path). The columns
        are present even when no file was recognised.
    """
    rows = []
    skipped = 0

    # Walk through the root directory and its subdirectories
    for subdir, _, files in os.walk(root_path):
        for file in files:
            file_path = os.path.join(subdir, file)
            labels = extract_labels(file_path)
            if labels is None:
                # Not a turtle image name we can label (e.g. stray system files)
                skipped += 1
                continue
            rows.append({
                'bonaire_turtle_id': labels[0],
                'side': labels[1],
                'filename': file_path,
            })

    if skipped:
        print(f"Skipped {skipped} file(s) whose names do not match a known naming format.")

    # Fix the column set explicitly so an empty folder still yields a usable table
    df = pd.DataFrame(rows, columns=['bonaire_turtle_id', 'side', 'filename'])

    # The CSV location is configured independently of the working directory,
    # so make sure its folder exists before writing.
    metadata_dir = os.path.dirname(auto_metadata_path)
    if metadata_dir:
        os.makedirs(metadata_dir, exist_ok=True)
    df.to_csv(auto_metadata_path, index=False)
    return df

## Generating metadata for the query files

In [3411]:
def generate_query_folder_metadata(root_path):
    """
    Build an in-memory metadata table for the query images under ``root_path``.

    Query images carry no identity in their file names, so the containing
    folder is used as the identity and the side is set to "U".

    Args:
        root_path (str): Root folder of the query images (walked recursively).

    Returns:
        pandas.DataFrame: One row per file with the columns
        'bonaire_turtle_id' (containing folder), 'side' (always "U") and
        'filename' (full path). The columns are present even when the folder
        holds no files, so downstream grouping does not fail on a missing key.
    """
    rows = []

    # Walk through the root directory and its subdirectories
    for subdir, _, files in os.walk(root_path):
        for file in files:
            rows.append({
                # All files of one folder share the folder path as their identity
                'bonaire_turtle_id': f"{subdir}",
                'side': "U",
                'filename': os.path.join(subdir, file),
            })

    return pd.DataFrame(rows, columns=['bonaire_turtle_id', 'side', 'filename'])

## Make sure there is metadata

In [3412]:
def ensure_metadata():
    """
    Make sure a metadata CSV exists for the gallery and report which one to use.

    The user-provided ``metadata_path`` is preferred; when it is not set the
    auto-generated ``auto_metadata_path`` is used. If the chosen file does not
    exist, the metadata is generated from ``dataset_path`` and written to
    ``auto_metadata_path``.

    Returns:
        str: Path of the metadata file that exists after this call.
    """
    # Prefer the user-provided metadata, fall back to the auto-generated file
    data_path = metadata_path if metadata_path else auto_metadata_path

    if os.path.isfile(data_path):
        return data_path

    # If there is no metadata, generate it
    print(f"There is no metadata file {data_path}")
    generate_data_folder_metadata(dataset_path)
    return auto_metadata_path

## Resize image tensor batch from any size to size N

In [3413]:
def resize_image_t_to_n_padded(image_batch, n=224):
    """
    Letterbox a batch of image tensors into n x n squares.

    Each image is scaled, keeping its aspect ratio, so that its LONGER side
    becomes ``n``; the shorter side is then padded symmetrically with black
    up to ``n``. Square images are simply resized to n x n.

    Args:
        image_batch (torch.Tensor): A tensor batch of images with shape (B, C, H, W), where B is the batch size,
                                   C is the number of channels (3 for RGB), H is the height, and W is the width.
        n (int): The target size for the square image (default 224).
    Returns:
        torch.Tensor: A tensor batch of images, each of size n x n. The images
        are round-tripped through PIL and ``ToTensor``.
    """
    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()

    batch_images = []
    for img_tensor in image_batch:
        img = to_pil(img_tensor)

        width, height = img.size
        if width == height:
            # If the image is already square, simply resize
            target_size = (n, n)
            padding = (0, 0, 0, 0)
        elif width < height:
            # Portrait: height becomes n, width is padded left/right.
            # max(1, ...) keeps very thin images from collapsing to 0 pixels.
            new_width = max(1, int(width * (n / height)))
            padding_left = (n - new_width) // 2
            padding_right = n - new_width - padding_left
            target_size = (n, new_width)
            padding = (padding_left, 0, padding_right, 0)
        else:
            # Landscape: width becomes n, height is padded top/bottom.
            new_height = max(1, int(height * (n / width)))
            padding_top = (n - new_height) // 2
            padding_bottom = n - new_height - padding_top
            target_size = (new_height, n)
            padding = (0, padding_top, 0, padding_bottom)

        # Resize takes (height, width); Pad takes (left, top, right, bottom)
        resized = transforms.Resize(target_size)(img)
        padded = transforms.Pad(padding, fill=0, padding_mode='constant')(resized)
        batch_images.append(to_tensor(padded))

    # Stack images into a single tensor batch
    return torch.stack(batch_images)


In [3414]:
def resize_image_to_n_padded(image_pil, n=224):
    """
    Letterbox a single PIL image into an n x n square tensor.

    The image is scaled, keeping its aspect ratio, so that its longer side
    becomes ``n``; the shorter side is padded symmetrically with black up to
    ``n``. Square images are simply resized to n x n.

    Args:
        image_pil (PIL.Image.Image): The image to resize.
        n (int): The target size for the square image (default 224).
    Returns:
        torch.Tensor: The n x n image converted with ``ToTensor``.
    """
    width, height = image_pil.size
    if width == height:
        # If the image is already square, simply resize
        target_size = (n, n)
        padding = (0, 0, 0, 0)
    elif width < height:
        # Portrait: height becomes n, width is padded left/right.
        # max(1, ...) keeps very thin images from collapsing to 0 pixels.
        new_width = max(1, int(width * (n / height)))
        padding_left = (n - new_width) // 2
        padding_right = n - new_width - padding_left
        target_size = (n, new_width)
        padding = (padding_left, 0, padding_right, 0)
    else:
        # Landscape: width becomes n, height is padded top/bottom.
        new_height = max(1, int(height * (n / width)))
        padding_top = (n - new_height) // 2
        padding_bottom = n - new_height - padding_top
        target_size = (new_height, n)
        padding = (0, padding_top, 0, padding_bottom)

    # Resize takes (height, width); Pad takes (left, top, right, bottom)
    resized = transforms.Resize(target_size)(image_pil)
    padded = transforms.Pad(padding, fill=0, padding_mode='constant')(resized)
    return transforms.ToTensor()(padded)


## Dataset Class

In [3415]:

class TurtlesDataset(Dataset):
    """
    Image dataset built from a metadata table with the columns
    'bonaire_turtle_id', 'side' and 'filename'.

    The table is either passed in directly (``data``) or read from the
    configured metadata CSV. Rows whose file does not exist on disk are
    dropped. Each distinct turtle id becomes one class; class indices follow
    the sorted order of the ids.

    Args:
        mode (str): Free-form mode tag, stored lower-cased in ``self.mode``.
        root (str | None): Folder the metadata path is resolved against when
            ``data`` is not given. An absolute metadata path ignores it.
        transform: Optional callable applied to the PIL image in ``__getitem__``.
        ignoreThreshold (int): When > 0, (id, side) groups with at most this
            many rows are dropped.
        data (pandas.DataFrame | None): Ready-made metadata table; when given,
            no CSV is read.

    Raises:
        FileNotFoundError: The metadata CSV is needed but does not exist.
        RuntimeError: No usable image remains after filtering.
    """

    def __init__(self, mode: str, root: str = None, transform=None, ignoreThreshold=0, data=None):
        self.mode = mode.lower()
        self.transform = transform
        # Always set, so the attribute exists regardless of how the table is supplied
        self.root = root
        if data is None:
            data_path = metadata_path if metadata_path else auto_metadata_path
            meta_path = os.path.join(root, data_path) if root else data_path
            if not os.path.isfile(meta_path):
                raise FileNotFoundError(f"No metadata file found (expected {data_path}).")
            data_df = pd.read_csv(meta_path)
        else:
            data_df = data

        print(f"Dataset size: {len(data_df)}")
        grouped = data_df.groupby(['bonaire_turtle_id', 'side'])
        group_counts = grouped.size()

        if ignoreThreshold > 0:
            multi_sample_groups = group_counts[group_counts > ignoreThreshold]
            print("Removing single-sample classes for BonaireTurtlesDataset.")
        else:
            multi_sample_groups = group_counts

        self.im_paths, self._y_strs, self.positions = [], [], []
        for (turtle_id, side), group_df in grouped:
            if (turtle_id, side) not in multi_sample_groups.index:
                continue

            for _, row in group_df.iterrows():
                filename = row.get("filename", "")
                # Missing values come back as NaN (a float), not as a string
                if not isinstance(filename, str):
                    continue
                filename = filename.strip()
                if not filename or not os.path.isfile(filename):
                    continue

                self.im_paths.append(filename)
                self._y_strs.append(f"{turtle_id}")
                self.positions.append(side)

        if not self.im_paths:
            raise RuntimeError("Dataset is empty.")

        self.classes = sorted(set(self._y_strs))
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.ys = [self.class_to_idx[s] for s in self._y_strs]
        print("Data length:", len(self.ys))

    def __getitem__(self, index):
        """Return (image, class index); the image is transformed and letterboxed when a transform is set."""
        # Load inside a context manager so the file handle is released right away
        with Image.open(self.im_paths[index]) as opened:
            img = opened.convert("RGB")
        if self.transform:
            img = self.transform(img).unsqueeze(0)
            # Letterbox to the detector input size using the batch-tensor resize helper
            img = resize_image_t_to_n_padded(img, full_resize_n).squeeze(0)
        target = self.ys[index]
        return img, target

    def get_path(self, index):
        """Return the file path of the image at ``index``."""
        return self.im_paths[index]

    def __len__(self):
        return len(self.ys)


In [3416]:
class TurtlesPathDataset(TurtlesDataset):
    """
    Variant of TurtlesDataset that hands out image file paths instead of
    decoded images, so decoding can happen later in the pipeline.
    """

    def __init__(self, mode: str, root: str = None, transform=None, ignoreThreshold=0, data=None):
        super().__init__(
            mode,
            root=root,
            transform=transform,
            ignoreThreshold=ignoreThreshold,
            data=data,
        )

    def __getitem__(self, index):
        """Return (image path, class index) for the sample at ``index``."""
        path = self.im_paths[index]
        target = self.ys[index]
        return path, target

    def __len__(self):
        """Number of image paths held by the dataset."""
        return len(self.im_paths)


## Get Labels from Dataset/DataLoader

In [3417]:
def get_dataloader_labels(dataloader):
    """
    Collect the labels of every batch produced by a DataLoader.

    The loader is iterated once from start to end; the first element of each
    batch is ignored and the second (the label tensor) is kept.

    Args:
        dataloader (torch.utils.data.DataLoader): Loader yielding (inputs, labels) batches.

    Returns:
        torch.Tensor: All label tensors concatenated along dimension 0, in loader order.
    """
    label_batches = [batch_labels for _, batch_labels in dataloader]
    return torch.cat(label_batches, dim=0)


## Resnet Implementation

In [3418]:


class Resnet101(nn.Module):
    """
    ResNet-101 backbone with a linear embedding head.

    The backbone's convolutional stages are reused as-is; the pooled feature
    vector is projected to ``embedding_size`` dimensions and, when ``is_norm``
    is set, L2-normalised.
    """

    def __init__(self,embedding_size, pretrained=True, is_norm=True, bn_freeze = True):
        """
        Args:
            embedding_size (int): Output dimension of the embedding layer.
            pretrained (bool): Forwarded to torchvision's ``resnet101``.
            is_norm (bool): L2-normalise the embeddings in ``forward``.
            bn_freeze (bool): Put every BatchNorm2d layer in eval mode and
                stop gradient updates of its weight and bias.
        """
        super(Resnet101, self).__init__()

        self.model = resnet101(pretrained)
        self.is_norm = is_norm
        self.embedding_size = embedding_size
        # Input width of the original classifier = width of the pooled features
        self.num_ftrs = self.model.fc.in_features
        self.model.gap = nn.AdaptiveAvgPool2d(1)
        self.model.gmp = nn.AdaptiveMaxPool2d(1)

        self.model.embedding = nn.Linear(self.num_ftrs, self.embedding_size)
        self._initialize_weights()

        if bn_freeze:
            for m in self.model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad_(False)
                    m.bias.requires_grad_(False)

    def l2_norm(self,input):
        """Scale each row of ``input`` to unit L2 norm (1e-12 is added under the root to avoid division by zero)."""
        input_size = input.size()
        buffer = torch.pow(input, 2)

        normp = torch.sum(buffer, 1).add_(1e-12)
        norm = torch.sqrt(normp)

        _output = torch.div(input, norm.view(-1, 1).expand_as(input))

        output = _output.view(input_size)

        return output

    def forward(self, x):
        """
        Compute the embeddings for a batch of images.

        The batch runs through the ResNet stem and layers 1-4; the resulting
        feature map is reduced by global average pooling and global max
        pooling, the two results are summed, flattened and projected by the
        embedding layer. With ``is_norm`` the embeddings are L2-normalised.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W).
            The image batch should be normalized to 
            resnet_mean = [0.485, 0.456, 0.406]
            resnet_std = [0.229, 0.224, 0.225]

        Returns:
            torch.Tensor: Embeddings of shape (B, embedding_size).
        """
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        avg_x = self.model.gap(x)
        max_x = self.model.gmp(x)

        # Combine both pooling results into one feature vector per image
        x = max_x + avg_x
        x = x.view(x.size(0), -1)
        x = self.model.embedding(x)
        
        if self.is_norm:
            x = self.l2_norm(x)
            
        return x

    def _initialize_weights(self):
        """Kaiming-normal init for the embedding weights, zeros for its bias."""
        init.kaiming_normal_(self.model.embedding.weight, mode='fan_out')
        init.constant_(self.model.embedding.bias, 0)

## Resnet101 Model

In [3419]:
# Build the architecture without fetching ImageNet weights: the strict
# load_state_dict below replaces every parameter and buffer anyway.
resnet101_model = Resnet101(embedding_size=512, pretrained=False, is_norm=True, bn_freeze=True)

# Path to checkpoint
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
checkpoint = torch.load(resnet101_checkpoint_path, map_location=device)
resnet101_model.load_state_dict(checkpoint)
resnet101_model.to(device)
resnet101_model.eval()



Resnet101(
  (model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          

## YOLO Detection 

In [3420]:
print(yolo_checkpoint_path)
# Load the detector and place it on the selected device instead of forcing CUDA,
# so the notebook also runs with use_gpu = False or without a GPU.
yolo_model = YOLO(yolo_checkpoint_path)  # You can specify a different model if needed
yolo_model.to(device)

/home/delta/Documents/Turtles/dataset_ops/Tag-A-Turtle/weights/best.pt


## Batch Process image tensors with YOLO
The function below takes an image tensor batch and crops each image around the turtle's head.
If no turtle head is found, the original image is used.

The function returns the tensor batch, cropped around the turtles' heads. 

In [3421]:
import torch
from PIL import Image, ImageDraw
from torchvision import transforms
import matplotlib.pyplot as plt

def batch_crop_yolov11(image_batch: torch.Tensor, model: YOLO, conf_threshold: float = 0.5, device: str = 'cuda'):
    """
    Crop each image of a tensor batch to its first detected box using YOLOv11.

    Every image is cropped to the first box the detector returns and then
    letterboxed to ``resize_scale``. If no usable box is found, the full image
    is letterboxed instead. The module-level ``error_crops`` counter is
    incremented for every image that does not have exactly one detection.

    Args:
        image_batch (torch.Tensor): Input tensor of shape (B, C, H, W).
        model (YOLO): Preloaded YOLOv11 model.
        conf_threshold (float): Confidence threshold for detections.
        device (str): Device to run the model on ('cuda' or 'cpu').

    Returns:
        torch.Tensor: Batch of cropped (or original) images, each
        resize_scale x resize_scale.
    """
    global error_crops

    # Ensure the model is on the correct device
    model.to(device)
    images = image_batch.to(device)

    # Perform inference
    results = model.predict(source=images, conf=conf_threshold, device=device, verbose=False)

    cropped_images = []
    for result, image in zip(results, images):
        boxes = [] if result.boxes is None else result.boxes.xyxy.int().cpu().tolist()

        # Zero or several detections both count as an unreliable crop
        if len(boxes) != 1:
            error_crops += 1

        region = image
        if boxes:
            x_min, y_min, x_max, y_max = boxes[0]
            # Negative indices would wrap around when slicing a tensor
            x_min, y_min = max(x_min, 0), max(y_min, 0)
            # Skip degenerate boxes: an empty slice cannot be resized
            if x_max > x_min and y_max > y_min:
                region = image[:, y_min:y_max, x_min:x_max]

        resized = resize_image_t_to_n_padded(region.unsqueeze(0), resize_scale)
        cropped_images.append(resized.squeeze(0))

    # Stack the cropped images into a single tensor
    return torch.stack(cropped_images)


## Single Yolo Crop

In [3422]:
def single_crop_yolov11(image_name: str, model: YOLO, conf_threshold: float = 0.5, device: str = 'cuda'):
    """
    Detect the turtle head in one image file and return the letterboxed crop.

    The image is cropped to the first box the detector returns. If no usable
    box is found, the full image is used. The module-level ``error_crops``
    counter is incremented whenever the image does not have exactly one
    detection.

    Args:
        image_name (str): Path of the image file.
        model (YOLO): Preloaded YOLOv11 model.
        conf_threshold (float): Confidence threshold for detections.
        device (str): Device to run the model on ('cuda' or 'cpu').

    Returns:
        torch.Tensor: The cropped (or original) image, letterboxed to
        resize_scale x resize_scale.
    """
    global error_crops

    # Perform inference
    results = model.predict(source=image_name, conf=conf_threshold, device=device, verbose=False)

    # NOTE(review): predict() decodes the file itself while the crop below is
    # taken from a separate PIL decode - confirm both agree on orientation
    # (e.g. EXIF-rotated photos).
    with Image.open(image_name) as opened:
        img = opened.convert("RGB")

    # A single source yields a single result
    result = results[0] if results else None
    if result is None or result.boxes is None:
        boxes = []
    else:
        boxes = result.boxes.xyxy.int().cpu().tolist()

    # Zero or several detections both count as an unreliable crop
    if len(boxes) != 1:
        error_crops += 1

    if boxes:
        x_min, y_min, x_max, y_max = boxes[0]
        # Skip degenerate boxes: a zero-size crop cannot be resized
        if x_max > x_min and y_max > y_min:
            img = img.crop((x_min, y_min, x_max, y_max))

    return resize_image_to_n_padded(img, resize_scale)

In [3423]:
def batch_crop_name(image_names: "list[str] | tuple[str, ...]", model: YOLO, conf_threshold: float = 0.5, device: str = 'cuda'):
    """
    Run ``single_crop_yolov11`` on every image path of a batch.

    Args:
        image_names: Batch of image file paths (not a tensor), e.g. the first
            element of a batch produced from TurtlesPathDataset.
        model (YOLO): Preloaded YOLOv11 model.
        conf_threshold (float): Confidence threshold for detections.
        device (str): Device to run the model on ('cuda' or 'cpu').

    Returns:
        torch.Tensor: The cropped images stacked along a new first dimension,
        in the order of ``image_names``.
    """
    cropped_images = [
        single_crop_yolov11(image_name, model, conf_threshold, device)
        for image_name in image_names
    ]
    return torch.stack(cropped_images)

## Standard Transform before Embedding model

In [3424]:
def make_transform(original = 256, output=224):
    """
    Build the preprocessing applied to each face tensor before embedding.

    The pipeline converts the tensor to a PIL image, center-crops it to
    ``output`` pixels, converts it back to a tensor and normalises it with
    the ResNet mean and standard deviation.

    Args:
        original (int): Size intended for a Resize step that is currently
            disabled; the value is not used by the returned transform.
        output (int): Side length of the center crop.

    Returns:
        torchvision.transforms.Compose: The composed transform.
    """
    steps = [
        transforms.ToPILImage(),
        transforms.CenterCrop(output),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
resnet_transform = make_transform(original=resize_scale)

## Resize, crop and normalize tensors for embedding model

In [3425]:
def pre_embed_transform(image_tensor_batch):
    """
    Apply ``resnet_transform`` to every image of the batch and re-stack the
    results into one tensor.
    """
    return torch.stack([resnet_transform(single_image) for single_image in image_tensor_batch])
    

## Inference Pipeline
Takes a batch of images (currently passed as file paths) and applies all the detection, cropping and transform steps required to produce the embeddings.

In [3426]:
def infer_pipeline(image_tensor):
    """
    Turn one batch of inputs into embeddings.

    Steps: crop each image to the detected face with the YOLO model, apply
    the pre-embedding transform, run the ResNet-101 embedding model and move
    the result to the CPU.

    Args:
        image_tensor: Batch handed to ``batch_crop_name`` (a batch of image
            file paths, despite the parameter name).

    Returns:
        torch.Tensor: The embeddings of the batch, on the CPU.
    """
    # Face detection and cropping
    face_crops = batch_crop_name(image_tensor, yolo_model, conf_threshold=0.5)

    # Center crop + normalisation, then move to the inference device
    model_input = pre_embed_transform(face_crops).to(device)
    del face_crops  # release the intermediate batch early

    # Embedding model
    batch_embeddings = resnet101_model(model_input)
    del model_input

    return batch_embeddings.cpu()

In [3427]:
def getErrorRate():
    """Return the current value of the module-level ``error_crops`` counter (a count, not a ratio)."""
    return error_crops

## Function to run the inference pipeline over the full dataset

In [3428]:
def process_images_to_embeddings(data_loader):
    """
    Run the inference pipeline over every batch of ``data_loader``.

    The progress bar shows the share of images processed in THIS call for
    which the crop was counted as an error. The rate is updated after each
    batch and is based on the number of images actually processed, so a
    partial last batch and errors from earlier runs do not distort it.

    Args:
        data_loader: Loader whose batches carry the pipeline input as their
            first element.

    Returns:
        torch.Tensor: The embeddings of all batches concatenated along
        dimension 0, on the CPU.
    """
    processed_batches = []
    # The error counter is global and keeps growing across calls
    errors_at_start = getErrorRate()
    images_seen = 0

    pbar = tqdm(data_loader, desc="Processing data")
    for values in pbar:
        processed_batch = infer_pipeline(values[0])
        processed_batches.append(processed_batch.cpu())

        images_seen += processed_batch.shape[0]
        error_rate_percentage = (getErrorRate() - errors_at_start) / max(images_seen, 1) * 100
        pbar.set_postfix({"Err Rate": f"{error_rate_percentage:.3f}%"})

    return torch.cat(processed_batches, dim=0).cpu()

In [3429]:

def turtle_similarities(query_embedding, gallery_embeddings):
    """
    Cosine similarity between a query embedding and the gallery embeddings.

    Args:
        query_embedding (torch.Tensor): Embedding of the query.
        gallery_embeddings (torch.Tensor): Embeddings to compare against.

    Returns:
        torch.Tensor: The result of ``F.cosine_similarity`` with its default
        arguments.
    """
    return F.cosine_similarity(query_embedding, gallery_embeddings)


In [3430]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

def plot_recognition_results(query_image, similar_images, titles):
    """
    Show the query image next to its most similar gallery images in one row.

    Args:
        query_image: The query as a file path, as a (path, label) pair such as
            an item of TurtlesPathDataset, or as already-loaded image data.
        similar_images: Gallery images in ranking order; each entry is a file
            path or already-loaded image data.
        titles: Title for each entry of ``similar_images`` (same order).
    """
    # Increase the figure size to make images 3x larger than the previous size
    plt.figure(figsize=(90, 45))  # 3x larger than the previous figsize
    n_panels = len(similar_images) + 1

    # Unwrap a (path, label) pair to its path
    if isinstance(query_image, (tuple, list)) and len(query_image) > 0 and isinstance(query_image[0], str):
        query_image = query_image[0]
    # Check the value itself: indexing a path string yields a 1-character
    # string, which would be mistaken for a path.
    if isinstance(query_image, str):
        print(query_image)
        query_image = mpimg.imread(query_image)

    plt.subplot(1, n_panels, 1)
    plt.imshow(query_image)
    plt.title("Query")

    for idx, sim_img in enumerate(similar_images):
        plt.subplot(1, n_panels, idx + 2)
        if isinstance(sim_img, str):
            sim_img = mpimg.imread(sim_img)
        plt.imshow(sim_img)
        plt.title(titles[idx])

    plt.show()


## Embeddings aggregator class

In [3431]:
class EmbeddingAggregator(nn.Module):
    """
    Average embeddings in the batch per turtle identity to produce more accurate embeddings.

    Parameters
    ----------
    reduction : str
        'mean' (default) or 'weighted'.
        When 'weighted', `weights` must be supplied in forward().
    return_index_map : bool
        If True, also returns a tensor that maps each original sample
        to its aggregated row – handy for gathering losses.
    """

    def __init__(self, reduction: str = "mean", return_index_map: bool = False):
        super().__init__()
        assert reduction in {"mean", "weighted"}
        self.reduction = reduction
        self.return_index_map = return_index_map

    @torch.no_grad()
    def forward(
        self,
        emb: torch.Tensor,             # shape (B, D)
        labels: torch.Tensor,          # shape (B,)
        weights: torch.Tensor | None = None
    ):
        device = emb.device
        labels = labels.to(device)

        # ---- gather bookkeeping -------------------------------------------------
        # unique ids and position of each sample’s aggregated row
        uniq_ids, inv = torch.unique(labels, return_inverse=True)
        num_ids = uniq_ids.size(0)

        if self.reduction == "mean":
            # quick path – use scatter_add then divide by counts
            summed = torch.zeros(num_ids, emb.size(1), device=device).scatter_add_(0,
                      inv.unsqueeze(-1).expand_as(emb), emb)
            counts = torch.bincount(inv, minlength=num_ids).unsqueeze(1)
            agg = summed / counts.clamp_min(1)

        else:  # weighted mean
            assert weights is not None, "`weights` required for weighted reduction"
            weights = weights.to(device).unsqueeze(1)  # (B, 1)
            w_sum = torch.zeros(num_ids, 1, device=device).scatter_add_(0, inv.unsqueeze(-1), weights)
            w_emb = torch.zeros_like(summed).scatter_add_(0, inv.unsqueeze(-1).expand_as(emb), emb * weights)
            agg = w_emb / w_sum.clamp_min(1e-8)

        if self.return_index_map:
            return agg, uniq_ids, inv
        return agg, uniq_ids, None



## Sort embeddings by similarity with their labels

In [3432]:
def sort_em_by_sim(similarities, query_labels):
    """
    Order similarity scores from highest to lowest and reorder the labels
    the same way.

    Args:
        similarities: Iterable of scalar tensors (each converted with ``.item()``).
        query_labels: Indexable sequence holding the label of each score.

    Returns:
        tuple: (sorted_similarities, sorted_labels) as two plain lists.
        Entries with equal scores keep their original relative order.
    """
    scores = []
    labels = []
    for position, similarity in enumerate(similarities):
        scores.append(similarity.item())
        labels.append(query_labels[position])

    # Rank the positions by score, best first (stable for ties)
    ranking = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)

    sorted_similarities = [scores[position] for position in ranking]
    sorted_labels = [labels[position] for position in ranking]
    return sorted_similarities, sorted_labels

## Retrieve the images in a dataset that match the label

In [3433]:
def find_matching_indices(string_list, query):
    """
    Return the positions in ``string_list`` whose entry equals ``query``,
    in ascending order (empty list when nothing matches).
    """
    hits = []
    for position, candidate in enumerate(string_list):
        if candidate == query:
            hits.append(position)
    return hits

## Basic tensor transform

In [3434]:
# Minimal transform: only converts the loaded image to a tensor.
to_tensor_transform = transforms.Compose([transforms.ToTensor()])

## Main Pipeline Cell

In [None]:
def _paths_for_labels(path_dataset, labels):
    """Return, for each class index in ``labels``, the first image path of that class in ``path_dataset``."""
    first_path = {}
    for path, class_index in zip(path_dataset.im_paths, path_dataset.ys):
        first_path.setdefault(class_index, path)
    return [first_path[int(label)] for label in labels]


with torch.no_grad():
    embeddings_exist = os.path.exists(saved_embeddings_path)
    error_crops = 0

    # The gallery dataset is needed in both cases: its image paths are what
    # gets plotted, also when the embeddings themselves come from the cache.
    if embeddings_exist:
        ensure_metadata()
    else:
        generate_data_folder_metadata(dataset_path)
    dataset = TurtlesPathDataset(root=dataset_path, mode="all", transform=to_tensor_transform)

    # Load dataset embeddings
    if embeddings_exist:
        dataset_embeddings, dataset_labels = torch.load(saved_embeddings_path)
        # Cached rows are matched to gallery images by position
        if len(dataset_embeddings) != len(dataset):
            raise RuntimeError(
                f"Cached embeddings ({len(dataset_embeddings)}) do not match the gallery "
                f"({len(dataset)} images); delete {saved_embeddings_path} to regenerate them."
            )
    else:
        data_loader = DataLoader(dataset, batch_size=inference_batch_size, pin_memory=True, shuffle=False, num_workers=1)
        dataset_embeddings = process_images_to_embeddings(data_loader)
        dataset_labels = get_dataloader_labels(data_loader)
        torch.save((dataset_embeddings, dataset_labels), saved_embeddings_path)

    # Load query embeddings
    query_metadata = generate_query_folder_metadata(query_folder_path)
    print(f"Found {len(query_metadata)} query images")
    query_dataset = TurtlesPathDataset(data=query_metadata, mode="all", transform=to_tensor_transform)
    query_loader = DataLoader(query_dataset, batch_size=inference_batch_size, pin_memory=True, shuffle=False, num_workers=1)
    query_embeddings = process_images_to_embeddings(query_loader)
    query_labels = get_dataloader_labels(query_loader)

    # One image path per embedding row, in the same order as the embeddings
    gallery_paths = list(dataset.im_paths)
    query_paths = list(query_dataset.im_paths)

    # Aggregate embeddings by identity if enabled. Each aggregated row is then
    # represented by the first image of its identity.
    aggregator = EmbeddingAggregator().to(device)
    if aggregate_dataset_identities:
        dataset_embeddings, dataset_labels, _ = aggregator(dataset_embeddings, labels=dataset_labels)
        gallery_paths = _paths_for_labels(dataset, dataset_labels)
    if aggregate_query_identities:
        query_embeddings, query_labels, _ = aggregator(query_embeddings, labels=query_labels)
        query_paths = _paths_for_labels(query_dataset, query_labels)

    # Local copy: do not overwrite the configured top_n
    n_shown = min(top_n, len(gallery_paths))

    for query_index in range(query_embeddings.shape[0]):
        # Get the similarities between query and dataset embeddings
        similarities = turtle_similarities(query_embeddings[query_index], dataset_embeddings)

        # Rank the gallery images and keep the top N
        # (slice the ranked list, not the returned tuple)
        _, ranked_paths = sort_em_by_sim(similarities, gallery_paths)
        top_paths = ranked_paths[:n_shown]
        top_titles = [os.path.basename(path) for path in top_paths]

        # Plot the query image itself next to its matches
        query_image = (query_paths[query_index], query_labels[query_index])
        plot_recognition_results(query_image, top_paths, top_titles)

    

Dataset size: 8396
Data length: 8396


Processing data:  37%|███▋      | 193/525 [03:03<09:35,  1.73s/it, Err Rate=1.289%]