In [None]:
import os
from tqdm import tqdm

import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import faiss

In [None]:
from deep_cluster.embeddings import calculate_embeddings_exp_1
from deep_cluster.embeddings import calculate_embeddings_exp_2
from deep_cluster.embeddings import calculate_embeddings_exp_3
from deep_cluster.embeddings import calculate_embeddings_exp_4
from deep_cluster.embeddings import calculate_embeddings_exp_5
from deep_cluster.embeddings import calculate_embeddings_exp_6
from deep_cluster.embeddings import calculate_embeddings_exp_7
from deep_cluster.embeddings import calculate_embeddings_exp_8
from deep_cluster.embeddings import calculate_embeddings_exp_9
from deep_cluster.embeddings import calculate_embeddings_exp_10

# Data Prep

In [None]:
# CSV describing the labeled images (read into `labeled_df` below)
labeled_data_path = "../data/Madeline_Data/MadelineIs_Modified.csv"
# CSV describing the unlabeled images (read into `unlabeled_df` below)
unlabeled_data_path = "../data/Unlabeled_AUV_Data/Unlabeled_AUV_Data.csv"

In [None]:
# Load the labeled CSV, keep only the image path and the three label columns,
# and discard every row that is missing any of them
labeled_columns = ['Image_Path', 'Sclass', 'Ssubclass', 'Sgroup']
labeled_df = (
    pd.read_csv(labeled_data_path)[labeled_columns]
    .dropna(subset=labeled_columns)
)

In [None]:
# Load the unlabeled CSV and conform it to the labeled layout:
# expose `image_path` under the name `Image_Path`, keep only that column,
# then fill the three label columns with the placeholder "Unlabeled"
unlabeled_df = (
    pd.read_csv(unlabeled_data_path)
    .assign(Image_Path=lambda frame: frame['image_path'])
    .loc[:, ['Image_Path']]
    .assign(Sclass="Unlabeled", Ssubclass="Unlabeled", Sgroup="Unlabeled")
)

In [None]:
# Stack the labeled rows on top of the unlabeled rows and renumber them 0..n-1
all_data_df = pd.concat([labeled_df, unlabeled_df], axis=0, ignore_index=True)

In [None]:
# Report the size of the combined frame and of each of its two sources
print(
    f"Total samples: {len(all_data_df)}",
    f"Labeled: {len(labeled_df)}",
    f"Unlabeled: {len(unlabeled_df)}",
    sep="\n",
)

In [None]:
# ---> Making modifications to the dataframe Image Paths (after moving locally) <---
# Insert the 'JordanP\' folder in front of 'name_tbd' in every image path.
old_fragment = 'name_tbd'
new_fragment = 'JordanP\\name_tbd'
image_paths = all_data_df['Image_Path']

# Re-running this cell must not prefix the same path twice, so paths that already
# contain the new fragment are left untouched (missing paths count as "not moved").
already_moved = image_paths.str.contains(new_fragment, regex=False, na=False)

# regex=False is required: with regex matching (the default before pandas 2.0) the
# replacement is interpreted as an `re` template, in which the backslash followed
# by 'n' at the start of '\name_tbd' would be turned into a newline character.
all_data_df['Image_Path'] = image_paths.where(
    already_moved,
    image_paths.str.replace(old_fragment, new_fragment, regex=False),
)

## Calculate Embeddings

To build an image-to-image search engine, we need to store:

1. An embedding representation of each image in our dataset, for use in querying for related images, and;
2. An index that maps the index of an embedding in the `faiss` data store to file names (faiss associates values with indices, so we need a way to map these indices back to file names).

In [None]:
def get_experiment_embedding_methods():
    """Map each experiment's short name to the function that calculates its embeddings."""
    # The three groups mirror the grouping of the experiment list; merging them
    # in this order keeps the experiments in their numbered order (1 to 10).
    pixel_experiments = {
        "raw_pixels": calculate_embeddings_exp_1,
        "normalized_pixels": calculate_embeddings_exp_2,
        "pca_pixels": calculate_embeddings_exp_3,
        "rgb_features": calculate_embeddings_exp_4,
    }
    descriptor_experiments = {
        "hog_features": calculate_embeddings_exp_5,
        "lbp_features": calculate_embeddings_exp_6,
        "sift_features": calculate_embeddings_exp_7,
    }
    network_experiments = {
        "yolov8n_features": calculate_embeddings_exp_8,
        "yolov8x_features": calculate_embeddings_exp_9,
        "dino_features": calculate_embeddings_exp_10,
    }
    return {**pixel_experiments, **descriptor_experiments, **network_experiments}


# Name -> function lookup used to pick the experiment to run
embedding_methods = get_experiment_embedding_methods()

In [None]:
# Name of the dataframe column that holds the image file paths
image_column = 'Image_Path'
# Value passed as `imgsz` to the embedding function and to display_similar_images below
imgsz = 112

In [None]:
# Pick the embedding experiment by name, then call it on the combined dataframe
embedding_method = embedding_methods['yolov8n_features']
embeddings = embedding_method(all_data_df, image_column=image_column, imgsz=imgsz)

## Prepare to Search the Index

Below, we define a function that builds the `faiss` index from our embeddings and keeps a lookup from each image path to its embedding. Once the index exists, we can retrieve the `k` images whose embeddings are most similar to the embedding of an input image. In other words, we can search for related images in our index and get back their positions in the index.

In [None]:
def create_indices(
    df: pd.DataFrame,
    embeddings: np.ndarray,
    image_column: str = 'Image_Path',
    index_path: "str | None" = "data.bin",
) -> "tuple[faiss.IndexFlatL2, dict]":
    """
    Build a FAISS L2 index over `embeddings` plus an image-path -> embedding lookup.

    Row `i` of `embeddings` must belong to row `i` (by position) of `df`, so the
    position of a vector inside the index is also the positional row of its image
    in `df`.

    Args:
        df (pd.DataFrame): DataFrame with one row per image.
        embeddings (np.ndarray): 2-D array with one embedding per row of `df`.
        image_column (str, optional): Column of `df` holding the image paths.
            Defaults to 'Image_Path'.
        index_path (str | None, optional): File the index is written to.
            Pass None to skip saving. Defaults to "data.bin".

    Returns:
        tuple: `(index, image_to_embeddings)` where `index` is the populated
        `faiss.IndexFlatL2` and `image_to_embeddings` maps each image path to the
        float32 vector stored for it (for a repeated path the last row wins).

    Raises:
        ValueError: If `embeddings` is not 2-D or its row count differs from `len(df)`.
    """
    # The index is fed C-contiguous float32 data; this is a no-op when the
    # input already has that layout.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)

    if vectors.ndim != 2:
        raise ValueError(
            f"embeddings must be 2-D (n_images, dim), got shape {vectors.shape}"
        )
    if len(df) != vectors.shape[0]:
        raise ValueError(
            f"df has {len(df)} rows but embeddings has {vectors.shape[0]} rows"
        )

    # Create a FAISS index (L2 distance) and add every vector in a single call
    indices = faiss.IndexFlatL2(vectors.shape[1])
    if len(vectors):
        indices.add(vectors)

    # Dictionary to hold the image paths and their corresponding embeddings
    image_to_embeddings = dict(zip(df[image_column], vectors))

    # Save the index to a file
    if index_path is not None:
        faiss.write_index(indices, index_path)

    return indices, image_to_embeddings

In [None]:
# Build the index over all images and keep the path -> embedding lookup
indices, image_to_embeddings = create_indices(df=all_data_df, embeddings=embeddings)

## Search the Index

The code below picks an input image (a random image carrying the chosen label), looks up the embedding already calculated for it, and uses that embedding to find related images.

We display the `k` top results (default 3) in the notebook.

In [None]:
def search_index(indices: "faiss.IndexFlatL2", embeddings: np.ndarray, k: int = 3) -> np.ndarray:
    """
    Search the index for the images that are most similar to the provided image.

    Args:
        indices (faiss.IndexFlatL2): The index to query.
        embeddings (np.ndarray): Embedding of the single query image; it is
            flattened into one row before the search.
        k (int, optional): Number of neighbours to request. Defaults to 3.

    Returns:
        np.ndarray: Positions (within the index) of the nearest neighbours, in
        the order returned by the index. Fewer than `k` entries are returned
        when the index could not supply `k` results.
    """
    # Hand the index a single C-contiguous float32 row
    query = np.ascontiguousarray(embeddings, dtype=np.float32).reshape(1, -1)

    # Distances are not needed by callers; only the neighbour positions are kept
    _, neighbours = indices.search(query, k)
    positions = neighbours[0]

    # FAISS pads missing results with -1; dropping them prevents a later
    # positional lookup such as df.iloc[-1] from silently returning the last row
    return positions[positions >= 0]

In [None]:
# Pick based on label, class name
label_columns = ["Sgroup", "Ssubclass", "Sclass"]

# For every label level, print a banner followed by the number of images per class
for label_column in label_columns:
    print(
        f"\n {'=' * 5} Label column: {label_column} {'=' * 5}",
        all_data_df[label_column].value_counts(),
        sep="\n",
    )

In [None]:
# Define the search image
label_column = "Sgroup"
label = "Gravel Mixes"

# Obtain a random sample of the specified label
matches = all_data_df[all_data_df[label_column] == label]
if matches.empty:
    raise ValueError(f"No rows with {label_column} == {label!r}")
sample = matches.sample(1)

# Get the image path
sample_path = sample['Image_Path'].values[0]

# Get the corresponding embedding
sample_embedding = image_to_embeddings[sample_path]

# Read the image; cv2.imread returns None (it does not raise) when the file
# cannot be read, so fail here with a clear message instead of inside cv2.resize
sample_image = cv2.imread(sample_path)
if sample_image is None:
    raise FileNotFoundError(f"Could not read image: {sample_path}")

# Resize, then convert OpenCV's BGR channel order to the RGB order matplotlib
# expects (same conversion as in display_similar_images) so colours are correct
sample_image = cv2.cvtColor(cv2.resize(sample_image, (256, 256)), cv2.COLOR_BGR2RGB)

plt.title(f"{os.path.basename(sample_path)} - {label_column} - {label}")
plt.imshow(sample_image)

In [None]:
# Number of nearest neighbours to retrieve
k = 10

# Query the index with the sample image's embedding
similar_indices = search_index(indices, sample_embedding, k=k)

In [None]:
# Echo the neighbour positions returned by search_index as the cell output
similar_indices

# Plot the Results

Plot the images that correspond to the k nearest neighbors of the search image in feature space.

In [None]:
def display_similar_images(indices: np.ndarray, df: pd.DataFrame, label_column: str, imgsz: int = 256, num_cols: int = 2, image_column: str = 'Image_Path'):
    """
    Displays similar images based on the provided indices.

    Images are laid out row by row in a grid of `num_cols` columns. An image
    that cannot be loaded is reported with a printed message and its grid cell
    is left blank; the remaining images are still shown.

    Args:
        indices (np.ndarray): Array of positional row indices into `df`.
        df (pd.DataFrame): DataFrame containing image paths and labels.
        label_column (str): The column name in the DataFrame for the image label.
        imgsz (int, optional): The size to resize the images to. Defaults to 256.
        num_cols (int, optional): The number of columns in the display grid. Defaults to 2.
        image_column (str, optional): The column name in the DataFrame for the
            image path. Defaults to 'Image_Path'.

    Raises:
        ValueError: If `num_cols` is smaller than 1.
    """
    if num_cols < 1:
        raise ValueError("num_cols must be at least 1")

    num_images = len(indices)
    if num_images == 0:
        # A grid with zero rows cannot be created, so there is nothing to draw
        print("No images to display.")
        return

    num_rows = (num_images + num_cols - 1) // num_cols  # Ceiling division

    # squeeze=False always yields a 2-D array of Axes, so single-row and
    # single-column grids need no special handling
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(12, 6 * num_rows), squeeze=False)

    # Start with every cell blank: unused trailing cells and cells whose image
    # failed to load then show nothing instead of an empty set of axes
    for ax in axes.flat:
        ax.axis('off')

    # axes.flat walks the grid row by row, matching the order of `indices`
    for ax, index in zip(axes.flat, indices):
        row = df.iloc[index]
        image_path = row[image_column]

        # Best effort per image: report the problem and move on to the next one
        try:
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread signals an unreadable file by returning None
                raise FileNotFoundError("file is missing or is not a readable image")

            image = cv2.resize(image, (imgsz, imgsz))
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # Convert BGR to RGB

            ax.imshow(image)
            ax.set_title(f"{os.path.basename(image_path)} - {label_column}: {row[label_column]}")
        except Exception as e:
            print(f"Error processing image {image_path}: {e}")
            continue

    plt.tight_layout()  # Adjust layout to prevent overlapping titles/labels
    plt.show()

In [None]:
# Show the retrieved neighbours in a two-column grid, titled with their labels
display_similar_images(
    indices=similar_indices,
    df=all_data_df,
    label_column=label_column,
    imgsz=imgsz,
    num_cols=2,
)