# Google Colab Turtle Recognition Pipeline - Tag a Turtle
## 1. Introduction
The goal of this pipeline is to identify sea turtles by their facial features using machine learning models. It processes turtle face images, creates embeddings for these faces, and matches them to known turtle identities.

To learn more about the project and the technical aspects of the pipeline, please read README.md.

## 2. Installation instructions
**Step 1**: Download this pipeline.ipynb file and upload it to your Google Drive.

**Step 2**: Upload the dataset to your drive if you haven't already.

**Step 3**: Create a folder in your Drive for the images you want to identify, and upload those images there.

**Step 4**: Download the fine-tuned model weights used by this pipeline and upload them to your Drive. To get a copy of the weights, contact your developers or dan.rayu@gmail.com.

**Step 5**: Open pipeline.ipynb and fill in the paths to the image dataset, the identify folder, and the model weights. The notebook raises a clear error if any of these files cannot be found.

Congrats! You have done all the necessary preparations to use the pipeline.

## 3. Usage instructions
**Step 1**: Insert images of unknown turtles into the predict folder found on your drive.

**Step 2**: Make sure this notebook is connected to a GPU.

* Check the top-right corner. If it shows “T4” (next to the ✅) or “Connect GPU”, you are good to go.
* If it shows nothing, or anything other than T4 or GPU, click the small arrow (▼) and select “Change runtime type”. In the pop-up menu, select T4 GPU and save.

**Step 3**: Click the play button (**Run all**) to run every cell in the notebook.

**Step 4**: A pop-up will ask for permission to connect to Google Drive. Accept it.

**Step 5**: Scroll to the bottom and wait for the results to appear.

## Extra troubleshooting
If the system keeps crashing, try deleting the RecognitionCache folder and all of its contents.

## Configuration
### Process configuration
Parameters that impact recognition results

In [1005]:

# Side length (px) of the letterboxed square face crop; the embedding model center-crops 224 px from it
resize_scale = 256

# Resize scale before passing to the detection model
full_resize_n = 1920

# Average all dataset images of the same turtle into one embedding per identity (may improve matching)
aggregate_dataset_identities = False

# Average all query images of the same turtle into one embedding per identity (may improve matching)
aggregate_query_identities = False

# Number of most similar results to show (e.g. top 10 images, or top 5 most similar turtles)
top_n = 10

### Parallelism configuration
Options that impact system speed and resource usage

In [1006]:
# Larger batch sizes speed up processing but may overflow memory.
inference_batch_size = 4

# Set to False to force CPU inference even when a GPU is available
use_gpu = True

## Mount Drive
Uncomment when running on Google Colab

In [1007]:
# @title
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)

### Paths configuration for Google Colab
Uncomment before using Google Colab

In [1008]:
# data_path_config = "Data"
# yolo_weights_config = 'best.pt'
# embedding_weights_config = 'best_fresh-field-240.pth'
# predict_folder_config = 'predict'
# basedir = "/content/drive/MyDrive/AI_for_Turtles"

# import os

# # Adjust these only for developers
# dataset_path = os.path.join(basedir, data_path_config)
# if not os.path.exists(dataset_path):
#     raise FileNotFoundError(f"Dataset path {dataset_path} does not exist.")
# yolo_checkpoint_path = os.path.join(basedir, yolo_weights_config)
# if not os.path.exists(yolo_checkpoint_path):
#     raise FileNotFoundError(f"Yolo checkpoint path {yolo_checkpoint_path} does not exist.")
# embedding_checkpoint_path = os.path.join(basedir, embedding_weights_config)
# if not os.path.exists(embedding_checkpoint_path):
#     raise FileNotFoundError(f"Embedding checkpoint path {embedding_checkpoint_path} does not exist.")
# query_folder_path = os.path.join(basedir, predict_folder_config)
# if not os.path.exists(query_folder_path):
#     raise FileNotFoundError(f"Query folder path {query_folder_path} does not exist.")

# # Path of the embedded dataset. This embedding file has to be regenerated if the model is updated.
# saved_embeddings_path = os.path.join(basedir, "RecognitionCache/embeddings/embeddings.pth")
# # Path of the generated metadata
# auto_metadata_path = os.path.join(basedir, f"RecognitionCache/datasets/auto_{dataset_path.split('/')[-1]}.csv")
# # Path of images in the same order as the embeddings
# embedding_image_path_list_path = os.path.join(basedir, "RecognitionCache/embeddings/embedding_image_path_list.csv")
# metadata_path = None

## Path configuration for PC
Comment out when running on Google Colab

In [1009]:
dataset_path = "/home/delta/Documents/Turtles/dataset_May15th/train/reid"
yolo_checkpoint_path = '/home/delta/Documents/Turtles/dataset_ops/Tag-A-Turtle/weights/best.pt'
embedding_checkpoint_path = "/home/delta/Documents/Turtles/inference_pipeline/best_fresh-field-240.pth"
side_detector_checkpoint_path = "/home/delta/Documents/Turtles/inference_pipeline/checkpoint_epoch_40.pth"
query_folder_path = "/home/delta/Documents/Turtles/Proxy-Anchor/query_images"

# Path of the embedded dataset. This embedding file has to be regenerated if the model is updated.
saved_embeddings_path = "./RecognitionCache/embeddings/embeddings.pth"
# Path of the generated metadata
auto_metadata_path = f"/home/delta/Documents/Turtles/inference_pipeline/RecognitionCache/datasets/auto_{dataset_path.split('/')[-1]}.csv"
# Path of images in the same order as the embeddings
embedding_image_path_list_path = f"/home/delta/Documents/Turtles/inference_pipeline/RecognitionCache/embeddings/embedding_image_path_list.csv"
metadata_path = None

### Testing configuration
Only for developers

In [1010]:
limit_dataset_images_testing = False

In [1011]:
# testing
error_crops = 0
def set_error_crops(n):
    global error_crops
    error_crops = n

## Install dependencies
Uncomment when running on Google Colab

In [1012]:
# %%capture
# @title
# !pip install torch tqdm torchvision pandas matplotlib Pillow ultralytics transformers

## Imports

In [1013]:
import torch
from tqdm import tqdm
import torch.nn.init as init
import os
import pandas as pd
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import resnet18, resnet101
from torch import nn
import torch.nn.functional as F
import matplotlib.image as mpimg
from PIL import Image, ImageOps
from ultralytics import YOLO
from transformers import Dinov2Model
import matplotlib.pyplot as plt
import numpy as np
import re

## Set the device on which to run

In [1014]:
# Use the GPU only when it is requested (use_gpu) and available; otherwise fall back to the CPU
device = torch.device("cuda" if use_gpu and torch.cuda.is_available() else "cpu")
if use_gpu and device.type == "cpu":
    print("GPU not available; falling back to CPU.")
# Verify the device being used
print(f"Using device: {device}")

Using device: cuda


## Utils

In [1015]:
def extract_labels(file_path):
    # Extract the filename from the full file path
    filename = os.path.basename(file_path)

    # Check for format 1: "03_066_R_2003_1.jpg" (any extension, underscores included)
    match1 = re.match(r"(\d{2}[-_]\d{3})_([A-Za-z])_\d{4}_\d+\.[a-zA-Z]+$", filename)
    if match1:
        return [match1.group(1), match1.group(2)]

    # Check for formats 2 and 3: "03_066 R.JPG", optionally preceded by free text and a space
    # (e.g. "Some prefix 03_066 R.JPG"); any extension, "-" or "_" separator
    match = re.match(r"(.*\s)?(\d{2}[-_]\d{3})\s([A-Za-z])\.[a-zA-Z]+$", filename)
    if match:
        # group(1) is the optional prefix, so the ID and side are groups 2 and 3
        return [match.group(2), match.group(3)]

    # Return None if no match found
    return None
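You can check the filename conventions that `extract_labels` accepts in isolation. The sketch below is a standalone restatement of the parser. The name `parse_turtle_filename` is hypothetical, and the optional prefix group is written as non-capturing so that the ID and side are always groups 1 and 2. The example filenames follow the formats named in the comments above.

```python
import os
import re

# Hypothetical standalone copy of the filename parser, for illustration only.
# Format 1: "03_066_R_2003_1.jpg" -> ID "03_066", side "R"
FORMAT_1 = re.compile(r"(\d{2}[-_]\d{3})_([A-Za-z])_\d{4}_\d+\.[a-zA-Z]+$")
# Formats 2 and 3: "03_066 R.JPG", optionally preceded by free text and a space
FORMAT_2_3 = re.compile(r"(?:.*\s)?(\d{2}[-_]\d{3})\s([A-Za-z])\.[a-zA-Z]+$")


def parse_turtle_filename(file_path):
    """Return [turtle_id, side] parsed from the file name, or None if no format matches."""
    filename = os.path.basename(file_path)
    for pattern in (FORMAT_1, FORMAT_2_3):
        match = pattern.match(filename)
        if match:
            return [match.group(1), match.group(2)]
    return None


examples = [
    "Data/03_066_R_2003_1.jpg",
    "Data/03_066 R.JPG",
    "Data/Some prefix 03-066 L.png",
    "Data/notes.txt",
]
for path in examples:
    print(path, "->", parse_turtle_filename(path))
```

Files that match none of the patterns (such as `notes.txt`) return `None` and are skipped when the metadata is generated.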

In [1016]:
os.makedirs("./RecognitionCache/datasets", exist_ok=True)
os.makedirs("./RecognitionCache/embeddings", exist_ok=True)


## Generating metadata for the dataset

In [1017]:
def generate_data_folder_metadata(root_path):
    file_paths = []

    # Walk through the root directory and its subdirectories
    for subdir, _, files in os.walk(root_path):
        for file in files:
            # Join the subdirectory and file name to get the full path
            file_paths.append(os.path.join(subdir, file))

    data = []
    for file_path in file_paths:
        labels = extract_labels(file_path)
        # Add the row to the new format data
        if labels:
            data.append({
                'bonaire_turtle_id': labels[0],
                'side': labels[1],
                'filename': file_path,
            })
    if limit_dataset_images_testing:
        data = data[:16]
    df = pd.DataFrame(data)
    return df

## Generating metadata for the query files

In [1018]:
def generate_query_folder_metadata(root_path):
    data = []

    # Walk through the root directory and its subdirectories
    for subdir, _, files in os.walk(root_path):
        for file in files:
            # Join the subdirectory and file name to get the full path
            data.append({
                'bonaire_turtle_id': f"{subdir}",
                'side': "U",
                'filename': os.path.join(subdir, file),
            })
    df = pd.DataFrame(data)
    return df

## Make sure there is metadata

In [1019]:
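def ensure_metadata():
    """Return the path of the metadata CSV, generating and saving it from the dataset folder if needed."""
    # Prefer user-provided metadata; otherwise use the auto-generated file
    data_path = metadata_path if metadata_path else auto_metadata_path

    # If there is no metadata, generate it and write it to disk
    if not os.path.isfile(data_path):
        print(f"There is no metadata file {data_path}")
        df = generate_data_folder_metadata(dataset_path)
        os.makedirs(os.path.dirname(auto_metadata_path), exist_ok=True)
        # index=False keeps exactly the three columns TurtlesPathDataset expects
        df.to_csv(auto_metadata_path, index=False)
        data_path = auto_metadata_path
    return data_path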
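`TurtlesPathDataset` unpacks each metadata row as `(turtle_id, side, filename)`. The CSV must therefore contain exactly these three columns, in this order, with no index column. The sketch below shows that layout using only the standard library. The sample rows and the in-memory buffer are illustrative assumptions; the notebook itself writes the file with pandas.

```python
import csv
import io

# Column order that TurtlesPathDataset unpacks from every row: (turtle_id, side, filename)
COLUMNS = ["bonaire_turtle_id", "side", "filename"]

# Illustrative rows; in the notebook they come from generate_data_folder_metadata()
rows = [
    {"bonaire_turtle_id": "03_066", "side": "R", "filename": "Data/03_066_R_2003_1.jpg"},
    {"bonaire_turtle_id": "03_066", "side": "L", "filename": "Data/03_066 L.JPG"},
]

# Write without an index column, the same layout as DataFrame.to_csv(index=False)
buffer = io.StringIO()
writer = csv.DictWriter(buffer, fieldnames=COLUMNS)
writer.writeheader()
writer.writerows(rows)

# Read it back: the header becomes column names, each data row unpacks into three fields
buffer.seek(0)
reader = csv.reader(buffer)
header = next(reader)
parsed = [tuple(row) for row in reader]
for turtle_id, side, filename in parsed:
    print(turtle_id, side, filename)
```

A file saved with pandas' default `index=True` would have a fourth column, and the three-way unpacking in the dataset class would fail.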

## Resize image tensor batch from any size to size N

In [1020]:
def resize_image_t_to_n_padded(image_batch, n=224):
    """
    Takes a batch of image tensors, resizes the larger dimension to 'n', and pads
    the smaller dimension with black to fit a square of size n x n. Returns a batch of
    images in tensor form.

    Args:
        image_batch (torch.Tensor): A tensor batch of images with shape (B, C, H, W), where B is the batch size,
                                   C is the number of channels (3 for RGB), H is the height, and W is the width.
        n (int): The target size for the square image (default 224).
    Returns:
        torch.Tensor: A tensor batch of images, each of size n x n.
    """
    transform = transforms.Compose([
        transforms.Resize((n, n)),
        transforms.ToTensor()
    ])

    batch_images = []
    for img_tensor in image_batch:
        img = transforms.ToPILImage()(img_tensor)

        width, height = img.size
        if width == height:
            # If the image is already square, simply resize
            transformed_image = transform(img)
        elif width < height:
            # Resize height to n, calculate padding for width
            new_height = n
            new_width = int(width * (new_height / height))
            padding_left = (n - new_width) // 2
            padding_right = n - new_width - padding_left
            transformed_image = transforms.Compose([
                transforms.Resize((new_height, new_width)),
                transforms.Pad((padding_left, 0, padding_right, 0), fill=0, padding_mode='constant'),
                transforms.ToTensor()
            ])(img)
        else:  # height < width
            # Resize width to n, calculate padding for height
            new_width = n
            new_height = int(height * (new_width / width))
            padding_top = (n - new_height) // 2
            padding_bottom = n - new_height - padding_top
            transformed_image = transforms.Compose([
                transforms.Resize((new_height, new_width)),
                transforms.Pad((0, padding_top, 0, padding_bottom), fill=0, padding_mode='constant'),
                transforms.ToTensor()
            ])(img)

        batch_images.append(transformed_image)

    # Stack images into a single tensor batch
    return torch.stack(batch_images)


In [1021]:
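def resize_image_to_n_padded(image_pil, n=224):
    """
    Resize a single PIL image so its larger dimension is 'n', then pad the smaller
    dimension with black to an n x n square. Returns a tensor of shape (C, n, n).
    """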
    transform = transforms.Compose([
        transforms.Resize((n, n)),
        transforms.ToTensor()
    ])

    img = image_pil

    width, height = img.size
    if width == height:
        # If the image is already square, simply resize
        transformed_image = transform(img)
    elif width < height:
        # Resize height to n, calculate padding for width
        new_height = n
        new_width = int(width * (new_height / height))
        padding_left = (n - new_width) // 2
        padding_right = n - new_width - padding_left
        transformed_image = transforms.Compose([
            transforms.Resize((new_height, new_width)),
            transforms.Pad((padding_left, 0, padding_right, 0), fill=0, padding_mode='constant'),
            transforms.ToTensor()
        ])(img)
    else:  # height < width
        # Resize width to n, calculate padding for height
        new_width = n
        new_height = int(height * (new_width / width))
        padding_top = (n - new_height) // 2
        padding_bottom = n - new_height - padding_top
        transformed_image = transforms.Compose([
            transforms.Resize((new_height, new_width)),
            transforms.Pad((0, padding_top, 0, padding_bottom), fill=0, padding_mode='constant'),
            transforms.ToTensor()
        ])(img)

    return transformed_image
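The letterboxing arithmetic is easier to follow with concrete numbers. The helper `letterbox_geometry` below is hypothetical. It repeats the size and padding computation of `resize_image_to_n_padded` without touching any image data. The padding tuple uses the `(left, top, right, bottom)` order that `transforms.Pad` expects.

```python
def letterbox_geometry(width, height, n=224):
    """Hypothetical helper mirroring resize_image_to_n_padded's arithmetic.

    Returns the resized (width, height) and the (left, top, right, bottom)
    zero padding that together produce an n x n square.
    """
    if width == height:
        return (n, n), (0, 0, 0, 0)
    if width < height:
        # Portrait: the height becomes n, the width is scaled and padded left/right
        new_w, new_h = int(width * (n / height)), n
        left = (n - new_w) // 2
        return (new_w, new_h), (left, 0, n - new_w - left, 0)
    # Landscape: the width becomes n, the height is scaled and padded top/bottom
    new_w, new_h = n, int(height * (n / width))
    top = (n - new_h) // 2
    return (new_w, new_h), (0, top, 0, n - new_h - top)


# A 300 x 512 (portrait) face crop letterboxed to resize_scale = 256
print(letterbox_geometry(300, 512, n=256))
# A 1024 x 512 (landscape) crop
print(letterbox_geometry(1024, 512, n=256))
```

Take the 300 x 512 crop as an example. The height is scaled to 256 and the width to 150. The remaining 106 columns are split evenly into 53 black columns on each side.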


## Dataset Class

In [1022]:

class TurtlesPathDataset(Dataset):
    """Dataset of (image path, class index) pairs built from a metadata CSV or DataFrame.

    `mode` and `transform` are stored but not applied here; `ignoreThreshold` is unused.
    """
    def __init__(self, mode: str, root: str = None, transform=None, ignoreThreshold=0, data=None):
        self.mode = mode.lower()
        self.transform = transform
        self.root = root
        if data is None:
            data_path = metadata_path if metadata_path else auto_metadata_path
            # Resolve the metadata path relative to `root` only when a root is given
            meta_path = os.path.join(root, data_path) if root else data_path
            if not os.path.isfile(meta_path):
                raise FileNotFoundError(f"No metadata file found (expected {meta_path}).")
            data_df = pd.read_csv(meta_path)
        else:
            data_df = data

        if limit_dataset_images_testing:
            data_df = data_df[:min(16, len(data_df))]

        self.im_paths, self._y_strs, self.positions = [], [], []
        # Each metadata row is (bonaire_turtle_id, side, filename)
        for _, (turtle_id, side, filename) in data_df.iterrows():
            # Skip rows whose image file no longer exists
            if not os.path.isfile(filename):
                continue
            self.im_paths.append(filename)
            self._y_strs.append(f"{turtle_id}")
            self.positions.append(side)

        if not self.im_paths:
            raise RuntimeError("Dataset is empty.")

        # Sort the identities so that class indices are reproducible across runs
        self.classes = sorted(set(self._y_strs))
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.ys = [self.class_to_idx[s] for s in self._y_strs]

    def __getitem__(self, index):
        return self.im_paths[index], self.ys[index]

    def get_path(self, index):
        return self.im_paths[index]

    def __len__(self):
        return len(self.im_paths)


## Get Labels from Dataset/DataLoader

In [1023]:
def get_dataloader_values(dataloader):
    """
    This function takes a PyTorch DataLoader and returns a tensor containing all the labels.

    Args:
        dataloader (torch.utils.data.DataLoader): A DataLoader object containing the dataset.

    Returns:
        torch.Tensor: A tensor containing all the labels.
    """
    all_labels = []
    for _, labels in dataloader:
        # The default collate function turns the integer labels of each batch into a tensor
        all_labels.extend(torch.as_tensor(labels).tolist())

    return torch.tensor(all_labels)


## Face Side Detection Implementation
A ResNet-18 classifier trained to detect which side of a turtle's face (left or right) is shown.

In [1024]:
class FaceSideModel(nn.Module):
    def __init__(self, pretrained=True, is_norm=True, bn_freeze=True):

        super(FaceSideModel, self).__init__()

        self.model = resnet18(pretrained=pretrained)

        self.is_norm = is_norm
        self.num_ftrs = self.model.fc.in_features
        self.model.gap = nn.AdaptiveAvgPool2d(1)
        self.model.gmp = nn.AdaptiveMaxPool2d(1)

        # Replace the last fully connected layer with a new one for 2 class output
        self.model.fc = nn.Linear(self.num_ftrs, 2)  # 2 classes: L and R

        if bn_freeze:
            for m in self.model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad_(False)
                    m.bias.requires_grad_(False)

        self._initialize_weights()

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        avg_x = self.model.gap(x)
        max_x = self.model.gmp(x)

        x = avg_x + max_x  # Combine average and max pooling features

        x = x.view(x.size(0), -1)
        x = self.model.fc(x)  # Final output layer

        if self.is_norm:
            x = self.l2_norm(x)

        return x

    def l2_norm(self, input):
        input_size = input.size()
        buffer = torch.pow(input, 2)

        normp = torch.sum(buffer, 1).add_(1e-12)
        norm = torch.sqrt(normp)

        _output = torch.div(input, norm.view(-1, 1).expand_as(input))

        output = _output.view(input_size)
        return output

    def _initialize_weights(self):
        # Only initialize the newly added fc layer
        if hasattr(self.model, 'fc') and isinstance(self.model.fc, nn.Linear):
            init.kaiming_normal_(self.model.fc.weight, mode='fan_out')
            init.constant_(self.model.fc.bias, 0)

## Resnet Implementation

In [1025]:
class Resnet101(nn.Module):
    def __init__(self,embedding_size, pretrained=True, is_norm=True, bn_freeze = True):
        super(Resnet101, self).__init__()

        self.model = resnet101(pretrained=pretrained)
        self.is_norm = is_norm
        self.embedding_size = embedding_size
        self.num_ftrs = self.model.fc.in_features
        self.model.gap = nn.AdaptiveAvgPool2d(1)
        self.model.gmp = nn.AdaptiveMaxPool2d(1)

        self.model.embedding = nn.Linear(self.num_ftrs, self.embedding_size)
        self._initialize_weights()

        if bn_freeze:
            for m in self.model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad_(False)
                    m.bias.requires_grad_(False)

    def l2_norm(self,input):
        input_size = input.size()
        buffer = torch.pow(input, 2)

        normp = torch.sum(buffer, 1).add_(1e-12)
        norm = torch.sqrt(normp)

        _output = torch.div(input, norm.view(-1, 1).expand_as(input))

        output = _output.view(input_size)

        return output

    def forward(self, x):
        """
        Compute embeddings for a batch of face crops.

        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W).
            The image batch should be normalized to
            resnet_mean = [0.485, 0.456, 0.406]
            resnet_std = [0.229, 0.224, 0.225]

        Returns:
            torch.Tensor: Embeddings of shape (B, embedding_size), L2-normalized when `is_norm` is True.
        """
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        avg_x = self.model.gap(x)
        max_x = self.model.gmp(x)

        x = max_x + avg_x
        x = x.view(x.size(0), -1)
        x = self.model.embedding(x)

        if self.is_norm:
            x = self.l2_norm(x)

        return x

    def _initialize_weights(self):
        init.kaiming_normal_(self.model.embedding.weight, mode='fan_out')
        init.constant_(self.model.embedding.bias, 0)


## Dino Implementation

In [1026]:

class Dinov2Small(nn.Module):
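    def __init__(self, embedding_size, pretrained=True, is_norm=True, bn_freeze=True):
        super(Dinov2Small, self).__init__()
        # `pretrained` and `bn_freeze` mirror Resnet101's signature but are unused here:
        # the DINOv2 backbone is always loaded pretrained and uses LayerNorm, not BatchNorm.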

        self.is_norm = is_norm
        self.embedding_size = embedding_size

        # Load pretrained DINOv2 backbone
        self.model = Dinov2Model.from_pretrained("facebook/dinov2-small")

        # Feature size from CLS token
        self.num_ftrs = self.model.config.hidden_size

        # Projection head
        self.model.embedding = nn.Linear(self.num_ftrs, self.embedding_size)
        self._initialize_weights()

    def l2_norm(self, input):
        input_size = input.size()
        buffer = torch.pow(input, 2)
        normp = torch.sum(buffer, 1).add_(1e-12)
        norm = torch.sqrt(normp)
        output = torch.div(input, norm.view(-1, 1).expand_as(input))
        return output.view(input_size)

    def forward(self, x):
        # Get CLS token
        x = self.model(pixel_values=x).last_hidden_state[:, 0]  # shape: (B, hidden_size)

        x = self.model.embedding(x)

        if self.is_norm:
            x = self.l2_norm(x)

        return x

    def _initialize_weights(self):
        init.kaiming_normal_(self.model.embedding.weight, mode='fan_out')
        init.constant_(self.model.embedding.bias, 0)


## Embedding Model

In [1027]:
# embedding_model = Resnet101(embedding_size=512, pretrained=True, is_norm=True, bn_freeze=True)
embedding_model = Dinov2Small(embedding_size=512, pretrained=True, is_norm=True, bn_freeze=True)
# Load the fine-tuned weights; map_location lets a GPU-saved checkpoint load on a CPU-only runtime
checkpoint = torch.load(embedding_checkpoint_path, map_location=device)
embedding_model.load_state_dict(checkpoint)
embedding_model.to(device)
embedding_model.eval()


## YOLO Detection

In [1028]:
print(yolo_checkpoint_path)
# .to(device) instead of .cuda() so the notebook also runs on CPU-only runtimes
yolo_model = YOLO(yolo_checkpoint_path).to(device)

/home/delta/Documents/Turtles/dataset_ops/Tag-A-Turtle/weights/best.pt


## Single Yolo Crop

In [1029]:
def single_crop_yolov11(image_name: str, model: YOLO, conf_threshold: float = 0.5, device=None):
    """
    Detect the turtle face in a single image file with YOLOv11 and return the letterboxed crop.
    If no face is detected, the whole image is used instead. Images with zero detections or with
    more than one detection are counted in the global `error_crops` counter.

    Args:
        image_name (str): Path to the image file.
        model (YOLO): Preloaded YOLOv11 model.
        conf_threshold (float): Confidence threshold for detections.
        device (str, optional): Device to run the model on ('cuda' or 'cpu'). Defaults to
            'cuda' when `use_gpu` is set and a GPU is available, otherwise 'cpu'.

    Returns:
        torch.Tensor: Tensor of shape (3, resize_scale, resize_scale).
    """
    global error_crops
    if device is None:
        device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"

    # Run detection on the image file
    results = model.predict(source=image_name, conf=conf_threshold, device=device, verbose=False)
    # Apply the EXIF orientation before converting so photos are upright
    img = ImageOps.exif_transpose(Image.open(image_name)).convert("RGB")

    for result in results:
        if result.boxes is not None and len(result.boxes.xyxy) > 0:
            # More than one detection is ambiguous: count it as an error but still use the first box
            if len(result.boxes.xyxy) != 1:
                error_crops += 1
            # First detected box as integers (x_min, y_min, x_max, y_max)
            x_min, y_min, x_max, y_max = result.boxes.xyxy[0].int().cpu().tolist()
            img = img.crop((x_min, y_min, x_max, y_max))
        else:
            # No detection: fall back to the full image
            error_crops += 1

    return resize_image_to_n_padded(img, resize_scale)

In [1030]:
def batch_crop_name(image_names, model: YOLO, conf_threshold: float = 0.5, device=None):
    """Crop every image path in `image_names` (a list or tuple of str) and stack the results."""
    cropped_images = [
        single_crop_yolov11(image_name, model, conf_threshold, device)
        for image_name in image_names
    ]
    return torch.stack(cropped_images)

## Standard Transform before Embedding model

In [1031]:
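def make_transform(original = 256, output=224):
    # Crops are already letterboxed to `original` (resize_scale) pixels by single_crop_yolov11,
    # so the Resize step below is disabled and only the center crop is applied.
    resnet_sz_resize = original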
    resnet_sz_crop = output
    resnet_mean = [0.485, 0.456, 0.406]
    resnet_std = [0.229, 0.224, 0.225]
    resnet_transform = transforms.Compose([
        transforms.ToPILImage(),
        # transforms.Resize(resnet_sz_resize),
        transforms.CenterCrop(resnet_sz_crop),
        transforms.ToTensor(),
        transforms.Normalize(mean=resnet_mean, std=resnet_std)
    ])
    return resnet_transform
resnet_transform = make_transform(original=resize_scale)
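The arithmetic of the crop-and-normalize step is shown below with concrete numbers. The helpers `center_crop_offset` and `normalize_pixel` are hypothetical and written for illustration only. torchvision's `CenterCrop` centers the window at an offset of about (size − crop) / 2. With the default `resize_scale = 256`, the crop therefore trims 16 px from every border of the letterboxed face. `ToTensor` followed by `Normalize` maps each channel value `v` to `(v / 255 - mean) / std`.

```python
# Hypothetical helpers reproducing the arithmetic of make_transform's
# CenterCrop(224) and Normalize(mean, std) steps, for illustration only.
RESNET_MEAN = [0.485, 0.456, 0.406]
RESNET_STD = [0.229, 0.224, 0.225]


def center_crop_offset(size, crop):
    """Offset of a centered crop x crop window inside a size x size image (same for top and left)."""
    return int(round((size - crop) / 2.0))


def normalize_pixel(rgb_uint8):
    """ToTensor maps 0..255 to 0..1, then Normalize computes (x - mean) / std per channel."""
    return [(v / 255.0 - m) / s for v, m, s in zip(rgb_uint8, RESNET_MEAN, RESNET_STD)]


offset = center_crop_offset(256, 224)
print("crop box (left, top, right, bottom):", (offset, offset, offset + 224, offset + 224))
# Black letterbox padding does not stay at 0 after normalization
print("normalized black pixel:", normalize_pixel([0, 0, 0]))
```

Note that the black padding added by the letterbox step becomes strongly negative after normalization, so the model sees it as a distinct, consistent border value.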

## Resize, crop and normalize tensors for embedding model

In [1032]:
def pre_embed_transform(image_tensor_batch):
    processed_images = []
    for image_tensor in image_tensor_batch:
        processed_images.append(resnet_transform(image_tensor))
    return torch.stack(processed_images)


## Inference Pipeline
Takes batches of image paths, crops the faces with YOLO, and applies the transforms and the embedding model to produce the embeddings.

In [1033]:
def infer_pipeline(image_paths):
    """Crop faces from a batch of image paths and return their embeddings on the CPU."""
    global yolo_model, embedding_model
    # Detect and crop the face in every image
    faces_tensor = batch_crop_name(image_paths, yolo_model, conf_threshold=0.5)
    # print("Number of cropping errors", error_crops)

    # Center crop and normalize before the embedding model
    normalized_face_tensor = pre_embed_transform(faces_tensor).to(device)
    del faces_tensor

    # Inference only: disable gradient tracking to save memory
    with torch.no_grad():
        embeddings = embedding_model(normalized_face_tensor)
    del normalized_face_tensor

    return embeddings.cpu()

In [1034]:
def getErrorRate():
    global error_crops
    return error_crops

## Function to run the inference pipeline over the full dataset

In [1035]:
def process_images_to_embeddings(data_loader):
    """Run the inference pipeline over every batch and concatenate the embeddings."""
    processed_batches = []
    errors_at_start = getErrorRate()
    images_seen = 0
    pbar = tqdm(data_loader, desc="Processing data")
    for paths, _ in pbar:
        processed_batches.append(infer_pipeline(paths))
        images_seen += len(paths)
        # Share of images processed so far with zero or multiple face detections
        error_rate_percentage = (getErrorRate() - errors_at_start) / images_seen * 100
        pbar.set_postfix({"Err Rate": f"{error_rate_percentage:.3f}%"})

    return torch.cat(processed_batches, dim=0)

## Similarity functions

In [1036]:

def turtle_similarities(query_embedding, gallery_embeddings):
    cosine_sim = F.cosine_similarity(query_embedding, gallery_embeddings)
    return cosine_sim
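Both embedding models L2-normalize their outputs (`is_norm=True`), so the cosine similarity between two embeddings equals their dot product. The pure-Python sketch below shows the scoring step, including the clamping done by `similarity_score`. The toy 2-D vectors and turtle IDs are made up for illustration; real embeddings have 512 dimensions.

```python
import math


def l2_normalize(vec):
    """Scale to unit length; like the models' l2_norm, 1e-12 is added inside the square root."""
    norm = math.sqrt(sum(v * v for v in vec) + 1e-12)
    return [v / norm for v in vec]


def cosine(a, b):
    """Cosine similarity: the dot product of the two L2-normalized vectors."""
    return sum(x * y for x, y in zip(l2_normalize(a), l2_normalize(b)))


def to_percentage(sim):
    """Clamp to [0, 1] and scale to a percentage, as similarity_score does."""
    return min(max(sim, 0.0), 1.0) * 100


# Toy 2-D "embeddings"
query = [3.0, 4.0]
gallery = {"03_066": [6.0, 8.0], "05_112": [4.0, -3.0], "07_001": [-3.0, -4.0]}
scores = {turtle_id: to_percentage(cosine(query, emb)) for turtle_id, emb in gallery.items()}
print(scores)
```

A parallel vector scores about 100%, and an orthogonal one scores about 0%. An opposite vector has cosine −1 and is clamped to 0%, so negative similarities never show up as negative percentages.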


In [1037]:
def similarity_score(similarity_value):
    if similarity_value < 0:
        return 0
    if similarity_value > 1:
        return 100
    return similarity_value * 100

In [1038]:
def pil_to_numpy(pil_image):
    return np.array(pil_image)

## Crop images with YOLO and orient them for matplotlib

In [1039]:

def display_correctly(image_path):
    global yolo_model
    cropped_image_tensor = single_crop_yolov11(image_path, yolo_model)
    pil_query_image = transforms.ToPILImage()(cropped_image_tensor)
    # Convert back to NumPy array for matplotlib
    return np.array(pil_query_image)

## Show recognition results

In [1040]:

def plot_recognition_results(query_image, similar_images, titles, similarity):
    # Large figure so that each of the top_n images is shown at a readable size
    plt.figure(figsize=(90, 45))
    plt.subplot(1, len(similar_images) + 1, 1)

    # The query may be an image array, a path, or a one-element list/tuple holding a path
    if isinstance(query_image, (list, tuple)) and len(query_image) > 0 and isinstance(query_image[0], str):
        query_image = query_image[0]
    if isinstance(query_image, str):
        print(query_image)
        query_image = display_correctly(query_image)
    plt.imshow(query_image)
    plt.title("Query")

    for idx, sim_img in enumerate(similar_images):
        plt.subplot(1, len(similar_images) + 1, idx + 2)
        # Image paths are cropped with YOLO before display
        if isinstance(sim_img, str):
            sim_img = display_correctly(sim_img)
        plt.imshow(sim_img)
        plt.title(titles[idx])

        # Display similarity below each similar image
        plt.text(0.5, -0.1, f"Similarity: {similarity_score(similarity[idx]):.1f}%", ha='center', va='top', transform=plt.gca().transAxes, fontsize=12)

    plt.show()

## Embeddings aggregator class

In [1041]:
class EmbeddingAggregator(nn.Module):
    """
    Average embeddings in the batch per turtle identity to produce more accurate embeddings.

    Parameters
    ----------
    reduction : str
        'mean' (default) or 'weighted'.
        When 'weighted', `weights` must be supplied in forward().
    return_index_map : bool
        If True, also returns a tensor that maps each original sample
        to its aggregated row – handy for gathering losses.
    """

    def __init__(self, reduction: str = "mean", return_index_map: bool = False):
        super().__init__()
        assert reduction in {"mean", "weighted"}
        self.reduction = reduction
        self.return_index_map = return_index_map

    @torch.no_grad()
    def forward(
        self,
        emb: torch.Tensor,             # shape (B, D)
        labels: torch.Tensor,          # shape (B,)
        weights: torch.Tensor | None = None
    ):
        device = emb.device
        labels = labels.to(device)

        # ---- gather bookkeeping -------------------------------------------------
        # unique ids and position of each sample’s aggregated row
        uniq_ids, inv = torch.unique(labels, return_inverse=True)
        num_ids = uniq_ids.size(0)

        if self.reduction == "mean":
            # quick path: use scatter_add then divide by counts
            summed = torch.zeros(num_ids, emb.size(1), device=device, dtype=emb.dtype).scatter_add_(
                0, inv.unsqueeze(-1).expand_as(emb), emb)
            counts = torch.bincount(inv, minlength=num_ids).unsqueeze(1)
            agg = summed / counts.clamp_min(1)

        else:  # weighted mean
            assert weights is not None, "`weights` required for weighted reduction"
            weights = weights.to(device=device, dtype=emb.dtype).unsqueeze(1)  # (B, 1)
            w_sum = torch.zeros(num_ids, 1, device=device, dtype=emb.dtype).scatter_add_(0, inv.unsqueeze(-1), weights)
            # Weighted sum of embeddings per identity
            w_emb = torch.zeros(num_ids, emb.size(1), device=device, dtype=emb.dtype).scatter_add_(
                0, inv.unsqueeze(-1).expand_as(emb), emb * weights)
            agg = w_emb / w_sum.clamp_min(1e-8)

        if self.return_index_map:
            return agg, uniq_ids, inv
        return agg, uniq_ids, None
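The `scatter_add_` and `bincount` calls in the mean branch compute a per-identity average. The loop-based sketch below (`mean_per_identity` is a hypothetical name) performs the same computation step by step. Like `torch.unique` with its default `sorted=True`, it returns the identities in ascending order.

```python
def mean_per_identity(embeddings, labels):
    """Average the embeddings that share a label, like EmbeddingAggregator(reduction="mean").

    Returns (aggregated, unique_labels); unique_labels is sorted ascending,
    the same order torch.unique returns by default.
    """
    sums, counts = {}, {}
    for emb, label in zip(embeddings, labels):
        if label not in sums:
            sums[label] = [0.0] * len(emb)
            counts[label] = 0
        # Accumulate the embedding into its identity's running sum
        sums[label] = [s + e for s, e in zip(sums[label], emb)]
        counts[label] += 1
    unique_labels = sorted(sums)
    aggregated = [[s / counts[label] for s in sums[label]] for label in unique_labels]
    return aggregated, unique_labels


# Three embeddings: two from turtle class 2, one from turtle class 0
agg, uniq = mean_per_identity([[1.0, 0.0], [3.0, 2.0], [0.0, 4.0]], [2, 0, 2])
print(uniq, agg)
```

An average of unit-length embeddings is generally shorter than unit length. This does not affect matching, because `F.cosine_similarity` divides by the vector norms.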



## Sort embeddings by similarity with their labels

In [1042]:
def sort_em_by_sim(similarities, labels):
    """Sort labels (in the pipeline: dataset image paths) by similarity, highest first."""
    if len(similarities) != len(labels):
        raise ValueError(f"{len(similarities)} similarities but {len(labels)} labels; they must be aligned")
    # Pair each similarity (as a Python float) with the label at the same position
    similarities_and_labels = [(similarity.item(), labels[i]) for i, similarity in enumerate(similarities)]

    # Sort the pairs by similarity in descending order
    similarities_and_labels_sorted = sorted(similarities_and_labels, key=lambda x: x[0], reverse=True)

    # Split the sorted pairs back into two parallel lists
    sorted_similarities = [pair[0] for pair in similarities_and_labels_sorted]
    sorted_labels = [pair[1] for pair in similarities_and_labels_sorted]

    return sorted_similarities, sorted_labels
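The main pipeline cell only keeps the first `top_n` entries of this sorted list, so a full sort is not strictly required. A minimal alternative sketch using the standard library's `heapq.nlargest`, assuming the similarities have already been converted to plain Python floats (for example with `similarities.tolist()`); `top_n_by_similarity` is a hypothetical helper name. Like a defensive version of the sort, it rejects mismatched lengths up front instead of failing later with an `IndexError`.

```python
import heapq


def top_n_by_similarity(similarities, labels, n):
    """Return the n highest similarities and their labels, best first."""
    if len(similarities) != len(labels):
        raise ValueError(f"{len(similarities)} similarities but {len(labels)} labels")
    # nlargest keeps a heap of at most n items: O(len * log n) instead of a full sort
    pairs = heapq.nlargest(n, zip(similarities, labels), key=lambda pair: pair[0])
    return [s for s, _ in pairs], [label for _, label in pairs]


scores, paths = top_n_by_similarity([0.12, 0.91, 0.47, 0.88], ["a.jpg", "b.jpg", "c.jpg", "d.jpg"], 2)
print(scores)  # [0.91, 0.88]
print(paths)   # ['b.jpg', 'd.jpg']
```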

## Find the indices of the entries that match a label

In [1043]:
def find_matching_indices(string_list, query):
    # Create a list of indices where the string matches the query
    matching_indices = [i for i, s in enumerate(string_list) if s == query]
    return matching_indices
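`find_matching_indices` scans the whole list on every call. When many labels are looked up against the same list, as happens once per aggregated query, grouping the indices once is cheaper. A sketch with `collections.defaultdict`; `build_label_index` is a hypothetical name, and the labels must be hashable by value (convert tensors with `.tolist()` first).

```python
from collections import defaultdict


def build_label_index(string_list):
    """Map each label to the list of positions where it occurs, in order."""
    index = defaultdict(list)
    for i, s in enumerate(string_list):
        index[s].append(i)
    return dict(index)


label_index = build_label_index(["24-191", "13-054", "24-191"])
print(label_index["24-191"])          # [0, 2]
print(label_index.get("99-999", []))  # [] for an unknown label
```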

## Retrieve cropped images from dataloader

## Basic tensor transform

In [1044]:
to_tensor_transform = transforms.Compose([
    transforms.ToTensor(),
])
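For a `uint8` PIL image or `H x W x C` array, `transforms.ToTensor()` reorders the data to `C x H x W` and rescales pixel values from 0..255 to floats in [0, 1]. A dependency-free sketch of that conversion on nested lists; the name `to_chw_unit_range` is hypothetical, and it skips the image-mode and dtype handling torchvision performs.

```python
def to_chw_unit_range(image_hwc):
    """Convert an H x W x C image of 0..255 ints to C x H x W floats in [0, 1]."""
    height = len(image_hwc)
    width = len(image_hwc[0])
    channels = len(image_hwc[0][0])
    # Channel becomes the outer axis; each value is scaled by 1/255
    return [
        [[image_hwc[y][x][c] / 255.0 for x in range(width)] for y in range(height)]
        for c in range(channels)
    ]


# A 1 x 2 RGB image: one red pixel and one white pixel
chw = to_chw_unit_range([[[255, 0, 0], [255, 255, 255]]])
print(chw)  # [[[1.0, 1.0]], [[0.0, 1.0]], [[0.0, 1.0]]]
```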

## Load the cached embedding image path list

In [1045]:
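def get_embedding_image_path_list():
    """Load the image paths saved next to the cached embeddings (one row per embedding)."""
    # The config variable is only read here, so no `global` statement is needed
    path = embedding_image_path_list_path
    if not os.path.isfile(path):
        return pd.DataFrame([])
    return pd.read_csv(path)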
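The path list has to stay row-aligned with the cached embeddings. If it is rebuilt by concatenating frames whose column labels differ (the string `"0"` read back from CSV versus the integer `0` of a freshly built frame), pandas spreads the paths over two half-empty columns, which appear as `0` and `0.1` when the file is read back. A standard-library sketch that reads such a file and keeps the first non-empty cell of every row; `read_single_path_column` is hypothetical.

```python
import csv
import io


def read_single_path_column(csv_file):
    """Return one path per data row, taking the first non-empty cell of each row."""
    reader = csv.reader(csv_file)
    next(reader, None)  # skip the header row (e.g. "0,0")
    paths = []
    for row in reader:
        non_empty = [cell for cell in row if cell.strip()]
        # Keep a None placeholder for empty rows so positions still match the embeddings
        paths.append(non_empty[0] if non_empty else None)
    return paths


broken_csv = io.StringIO("0,0\n/data/a.jpg,\n/data/b.jpg,\n,/data/c.jpg\n")
print(read_single_path_column(broken_csv))  # ['/data/a.jpg', '/data/b.jpg', '/data/c.jpg']
```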

## Detect new images

In [1046]:
def get_new_images():
    """Return metadata rows for dataset images that have not been embedded yet."""
    # No previous run recorded: every image in the dataset folder is new
    if not os.path.exists(auto_metadata_path):
        return generate_data_folder_metadata(dataset_path)

    old_metadata = pd.read_csv(auto_metadata_path)
    old_dataset = TurtlesPathDataset(root=None, mode="all", transform=to_tensor_transform, data=old_metadata)
    # A set makes each membership test O(1) instead of a scan of the old path list
    known_paths = set(old_dataset.im_paths)

    new_metadata = generate_data_folder_metadata(dataset_path)
    new_dataset = TurtlesPathDataset(root=None, mode="all", transform=to_tensor_transform, data=new_metadata)

    new_rows = []
    for index in range(len(new_dataset)):
        row = new_dataset[index]  # row[0] is the image path, row[1] the turtle id
        if row[0] not in known_paths:
            # The side of a newly detected image is unknown, so it is recorded as "U"
            new_rows.append([row[1], "U", row[0]])
    return pd.DataFrame(new_rows, columns=['bonaire_turtle_id', 'side', 'filename'])
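Stripped of the dataset classes, new-image detection is a set difference on image paths that preserves folder order. A sketch on plain lists of paths (`find_new_paths` and the example paths are hypothetical); building the set once keeps each membership test O(1). Note that, as with `get_new_images`, images deleted from the dataset folder are not detected this way, so their embeddings stay in the cache.

```python
def find_new_paths(old_paths, current_paths):
    """Return the paths in current_paths that are not in old_paths, keeping their order."""
    known = set(old_paths)
    return [path for path in current_paths if path not in known]


old = ["/turtles/13-054/a.jpg", "/turtles/24-191/b.jpg"]
current = ["/turtles/13-054/a.jpg", "/turtles/24-191/b.jpg", "/turtles/24-191/c.jpg"]
print(find_new_paths(old, current))  # ['/turtles/24-191/c.jpg']
```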

## Generate embeddings for new images

In [1047]:
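def embed_images_from_df(df):
    """Embed the images listed in df; return (embeddings, labels, paths) in matching row order."""
    new_images_dataset = TurtlesPathDataset(data=df, mode="all", transform=to_tensor_transform)
    new_images_data_loader = DataLoader(new_images_dataset, batch_size=inference_batch_size, pin_memory=True, shuffle=False, num_workers=1)
    image_labels = get_dataloader_values(new_images_data_loader)
    new_images_embeddings = process_images_to_embeddings(new_images_data_loader)

    img_paths = new_images_dataset.im_paths
    # Name the column "0" (a string) so it matches the header pandas reads back from the
    # saved path-list CSV; an integer 0 column would not align with it on pd.concat
    return new_images_embeddings, image_labels, pd.DataFrame({"0": list(img_paths)})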
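`embed_images_from_df` returns the embeddings, labels and paths as three parallel sequences, and the main cell appends each one to its cached counterpart. Results are only correct while all three grow together, so that row i of the embeddings always belongs to path i. A plain-Python sketch of that append step with an explicit check; `append_to_cache` and the list-based cache are assumptions for illustration, not the notebook's tensors and DataFrames.

```python
def append_to_cache(cache, new_embeddings, new_labels, new_paths):
    """Append new rows to a cache dict of parallel lists, refusing misaligned input."""
    if not (len(new_embeddings) == len(new_labels) == len(new_paths)):
        raise ValueError(
            f"misaligned batch: {len(new_embeddings)} embeddings, "
            f"{len(new_labels)} labels, {len(new_paths)} paths"
        )
    cache["embeddings"].extend(new_embeddings)
    cache["labels"].extend(new_labels)
    cache["paths"].extend(new_paths)
    return cache


cache = {"embeddings": [[0.1, 0.2]], "labels": ["13-054"], "paths": ["/turtles/a.jpg"]}
append_to_cache(cache, [[0.3, 0.4]], ["24-191"], ["/turtles/b.jpg"])
print(len(cache["embeddings"]), len(cache["paths"]))  # 2 2
```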

## Main Pipeline Cell

In [1048]:
with torch.no_grad():
    # ---- Dataset embeddings: load the cache, then embed any new images -------------
    dataset_image_embeddings, dataset_embedded_image_labels = torch.tensor([]), torch.tensor([])
    dataset_paths_list = pd.DataFrame([])
    new_images = get_new_images()
    embeddings_exist = os.path.exists(saved_embeddings_path)
    current_metadata = generate_data_folder_metadata(dataset_path)

    if embeddings_exist:
        dataset_image_embeddings, dataset_embedded_image_labels = torch.load(saved_embeddings_path)
        dataset_paths_list = get_embedding_image_path_list()

    # Add new images to the cache
    if new_images.shape[0] > 0:
        new_images_embeddings, new_images_labels, new_images_paths = embed_images_from_df(new_images)

        dataset_paths_list = pd.concat([dataset_paths_list, new_images_paths], ignore_index=True)
        dataset_image_embeddings = torch.cat([dataset_image_embeddings, new_images_embeddings], dim=0)
        dataset_embedded_image_labels = torch.cat([dataset_embedded_image_labels, new_images_labels], dim=0)

        # If the cached list and the new paths use different column labels ("0" vs 0),
        # concat spreads the paths over two half-empty columns; merge them into one
        dataset_paths_list = dataset_paths_list.bfill(axis=1).iloc[:, [0]]

        torch.save((dataset_image_embeddings, dataset_embedded_image_labels), saved_embeddings_path)
        dataset_paths_list.to_csv(embedding_image_path_list_path, index=False)
        current_metadata.to_csv(auto_metadata_path, index=False)

    # One image path per embedding row, in the same order as the embeddings
    # (bfill also repairs path lists saved earlier with the split-column problem)
    if dataset_paths_list.shape[1] > 0:
        embedding_paths = dataset_paths_list.bfill(axis=1).iloc[:, 0].tolist()
    else:
        embedding_paths = []
    if len(embedding_paths) != dataset_image_embeddings.shape[0]:
        raise RuntimeError(
            f"{len(embedding_paths)} cached image paths but {dataset_image_embeddings.shape[0]} cached "
            "embeddings; try deleting the RecognitionCache folder and running the pipeline again."
        )
    print("Dataset embeddings loaded.")

    # ---- Query embeddings -----------------------------------------------------------
    query_metadata = generate_query_folder_metadata(query_folder_path)
    print(f"Found {len(query_metadata)} query images")
    query_dataset = TurtlesPathDataset(data=query_metadata, mode="all", transform=to_tensor_transform)
    query_loader = DataLoader(query_dataset, batch_size=inference_batch_size, pin_memory=True, shuffle=False, num_workers=1)
    query_embeddings = process_images_to_embeddings(query_loader)
    image_query_labels = get_dataloader_values(query_loader)

    # ---- Aggregate embeddings by identity if enabled ------------------------------
    per_image_dataset_labels = dataset_embedded_image_labels.tolist()
    aggregator = EmbeddingAggregator().to(device)
    # Aggregating dataset identities underperforms, so it will be disabled for now.
    if aggregate_dataset_identities:
        dataset_image_embeddings, dataset_embedded_image_labels, _ = aggregator(dataset_image_embeddings, labels=dataset_embedded_image_labels)
    if aggregate_query_identities:
        query_embeddings, query_labels, _ = aggregator(query_embeddings, labels=image_query_labels)

    # Path shown for each dataset embedding row: the image itself or, for an
    # aggregated identity, the first cached image of that identity
    if aggregate_dataset_identities:
        first_path_by_label = {}
        for label, path in zip(per_image_dataset_labels, embedding_paths):
            first_path_by_label.setdefault(label, path)
        image_locations = [first_path_by_label[label] for label in dataset_embedded_image_labels.tolist()]
    else:
        image_locations = embedding_paths

    # ---- Rank dataset images for each query ----------------------------------------
    for query_index in range(query_embeddings.shape[0]):
        query_em = query_embeddings[query_index]

        # Similarity between this query and every dataset embedding
        similarities = turtle_similarities(query_em, dataset_image_embeddings)

        # Sort the similarities and keep the top N results
        n_results = min(top_n, len(similarities))
        sorted_similarities, sorted_locations = sort_em_by_sim(similarities, image_locations)
        sorted_locations = sorted_locations[:n_results]
        sorted_similarities = sorted_similarities[:n_results]
        print(sorted_similarities, sorted_locations)
        print("")

        # Image paths to plot, captioned with their filenames
        images = list(sorted_locations)
        images_labels = [os.path.basename(path) for path in sorted_locations]

        # Find the query image to plot next to the results
        query_image_index = query_index
        if aggregate_query_identities:
            query_label = query_labels[query_index]
            query_image_index = find_matching_indices(image_query_labels, query_label)[0]
            print(f"Query turtle number: {query_index+1} out of {query_embeddings.shape[0]}")
        else:
            print(f"Query image number: {query_index+1} out of {query_embeddings.shape[0]}")
        query_image = query_dataset[query_image_index]
        plot_recognition_results(query_image, images, images_labels, sorted_similarities)
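The per-query loop ranks every dataset embedding by its similarity to the query and keeps the best `top_n`. `turtle_similarities` is defined elsewhere in the notebook; the sketch below assumes it is cosine similarity, which may not match the real function, and uses plain lists so the ranking logic can be followed step by step. `rank_matches` is a hypothetical name.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between vectors a and b (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def rank_matches(query, gallery, paths, top_n):
    """Return the top_n (similarity, path) pairs for one query embedding, best first."""
    if len(gallery) != len(paths):
        raise ValueError("gallery embeddings and paths must be aligned one-to-one")
    scored = [(cosine_similarity(query, emb), path) for emb, path in zip(gallery, paths)]
    scored.sort(key=lambda pair: pair[0], reverse=True)
    # Never ask for more results than there are dataset images
    return scored[:min(top_n, len(scored))]


gallery = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
paths = ["a.jpg", "b.jpg", "c.jpg"]
print(rank_matches([1.0, 0.0], gallery, paths, 2))
```

Checking the lengths before ranking turns a misaligned cache into an explicit error instead of silently pairing scores with the wrong images.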


