In [None]:
!pip install faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.4 kB)
Downloading faiss_cpu-1.10.0-cp311-cp311-manylinux_2_28_x86_64.whl (30.7 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m30.7/30.7 MB[0m [31m72.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.10.0


In [None]:
import os
import subprocess

# Clone the helper repo to the absolute path that the imports cell appends to
# sys.path, and skip the clone when the checkout already exists so re-running
# the notebook does not fail on "destination path already exists".
REPO_DIR = "/content/pytorch-plantclef"
if not os.path.isdir(REPO_DIR):
    subprocess.run(
        ["git", "clone", "https://github.com/murilogustineli/pytorch-plantclef.git", REPO_DIR],
        check=True,
    )

Cloning into 'pytorch-plantclef'...
remote: Enumerating objects: 399, done.[K
remote: Counting objects: 100% (165/165), done.[K
remote: Compressing objects: 100% (97/97), done.[K
remote: Total 399 (delta 109), reused 80 (delta 68), pack-reused 234 (from 1)[K
Receiving objects: 100% (399/399), 103.53 MiB | 38.16 MiB/s, done.
Resolving deltas: 100% (217/217), done.


In [None]:
import pandas as pd
import torch
import numpy as np
import faiss
from collections import Counter
import os
import csv
import sys
sys.path.append('/content/pytorch-plantclef/plantclef')
from config import get_device, create_submission_csv

# NOTE(review): these paths live on Google Drive; no drive.mount(...) cell is
# visible in this notebook, so Drive must be mounted at /content/drive first.
EMBEDDINGS_DIR = "/content/drive/MyDrive/01_plantclef_datasets/embeddings"
# Training embeddings (read later for the image_name / species_id / embeddings columns).
TRAIN_EMBEDDINGS_PATH = f"{EMBEDDINGS_DIR}/train_embeddings_24k.parquet"
# Test-tile embeddings; no .parquet suffix, presumably a partitioned parquet directory (TODO confirm).
TEST_EMBEDDINGS_PATH = f"{EMBEDDINGS_DIR}/explode_test_embeddings_3x3"

In [None]:
train_df = pd.read_parquet(TRAIN_EMBEDDINGS_PATH)
test = pd.read_parquet(TEST_EMBEDDINGS_PATH)

# Fail fast if a column used downstream is missing: the classifier reads
# image_name/embeddings, the species lookup reads species_id, and the
# submission step groups test tiles by quadrat_id.
missing_train = {"image_name", "species_id", "embeddings"} - set(train_df.columns)
missing_test = {"quadrat_id", "embeddings"} - set(test.columns)
assert not missing_train, f"train_df is missing columns: {sorted(missing_train)}"
assert not missing_test, f"test is missing columns: {sorted(missing_test)}"

print(train_df.shape, test.shape)

(25791, 5) (18945, 4)


In [None]:
# Some training embeddings are stored nested as np.array([np.array([...])]).
# Unwrap that single outer level so each row is a flat vector before the
# rows are stacked into one matrix with np.array(...) further below.
def _unwrap_nested_embedding(emb):
    """Return ``emb[0]`` when the first element is itself an ndarray, otherwise ``emb`` as-is."""
    first = emb[0]
    if isinstance(first, np.ndarray):
        return first
    return emb


train_df["embeddings"] = train_df["embeddings"].apply(_unwrap_nested_embedding)

### FAISS nearest-neighbour classifier

In [None]:
class FaissClassifier:
    """Nearest-neighbour lookup over training embeddings using a FAISS flat index.

    Embeddings are L2-normalized both before indexing and before querying, so the
    inner-product index (``IndexFlatIP``) scores are cosine similarities.
    Search hits are mapped back to the training ``image_name`` of each vector.
    """

    def __init__(self, train_df: pd.DataFrame):
        """
        :param train_df: DataFrame with at least the columns ["image_name", "embeddings"];
            every "embeddings" entry must be a 1-D vector of the same length
        """
        self.device = get_device()
        self.index, self.idx2cls = self.build_index(train_df)

    def build_index(self, train_df: pd.DataFrame):
        """Build the FAISS index from the training data.

        :param train_df: DataFrame with "image_name" and "embeddings" columns
        :return: (index, idx2cls) where ``idx2cls[i]`` is the image_name of the
            i-th vector added to ``index``
        """
        # Row order of idx2cls matches the order vectors are added to the index,
        # so a FAISS result id is directly a position into idx2cls.
        idx2cls = train_df["image_name"].values
        embs_array = np.array(train_df["embeddings"].tolist(), dtype=np.float32)
        embs = torch.tensor(embs_array, device=self.device)
        # normalize so that inner product == cosine similarity
        embs = torch.nn.functional.normalize(embs, p=2, dim=1)
        index = faiss.IndexFlatIP(embs.shape[1])
        index.add(embs.cpu().numpy())  # FAISS expects CPU numpy arrays
        return index, idx2cls

    def make_prediction(self, query_embeddings, k=1):
        """
        Find the k nearest training images for each query embedding.

        :param query_embeddings: torch.Tensor or array-like of shape (N, D), or a
            single vector of shape (D,) which is treated as N=1
        :param k: number of nearest neighbours to return per query
        :return: (predictions, similarities), both of shape (N, k).
            ``predictions`` holds the training image_name of each neighbour (not a
            species id); ``similarities`` holds the cosine similarities. Slots FAISS
            could not fill (e.g. k larger than the index size) are ``None`` in
            ``predictions``.
        """
        query = torch.as_tensor(query_embeddings, dtype=torch.float32)
        if query.ndim == 1:
            query = query.unsqueeze(0)
        query = torch.nn.functional.normalize(query, p=2, dim=1)
        similarities, indices = self.index.search(query.cpu().numpy(), k=k)
        predictions = self.idx2cls[indices]
        # FAISS pads missing neighbours with id -1; plain numpy indexing would
        # silently map those to the LAST training image, so blank them out.
        missing = indices < 0
        if missing.any():
            predictions = predictions.astype(object)
            predictions[missing] = None
        return predictions, similarities

In [None]:
def create_classification_dataframe(
    train_df: pd.DataFrame,
    test_df: pd.DataFrame,
    predictions: np.ndarray,
    similarities: np.ndarray,
):
    """
    Attach FAISS neighbour results to the test rows and resolve them to species IDs.

    :param train_df: Train DataFrame providing the image_name -> species_id mapping
    :param test_df: Test DataFrame with one row per query embedding, in the same
        order as the rows of ``predictions`` / ``similarities``
    :param predictions: array of shape (N, K) with the nearest training image names
    :param similarities: array of shape (N, K) with the matching similarity scores
    :return: copy of ``test_df`` with three added columns: ``predictions`` and
        ``similarities`` (per-row lists of length K), and ``pred_species_ids``
        (species_id of each predicted image, or None when the image name is not
        found in ``train_df``)
    """
    cls_test_df = test_df.copy()
    cls_test_df["predictions"] = predictions.tolist()
    cls_test_df["similarities"] = similarities.tolist()
    # image_name -> species_id; if an image_name repeats, the last row wins
    image_to_species = dict(zip(train_df["image_name"], train_df["species_id"]))
    cls_test_df["pred_species_ids"] = [
        [image_to_species.get(img_name) for img_name in row]
        for row in cls_test_df["predictions"]
    ]
    return cls_test_df

In [None]:

def create_submission_csv(
    faiss_df: pd.DataFrame,
    output_path: str = "../data/submission/submission.csv",
    species_col: str = "pred_species_ids",
    save_csv: bool = False,
    top_k=None,
):
    """
    Aggregate tile-level FAISS species predictions per quadrat into a submission table.

    NOTE: this notebook-local definition shadows the ``create_submission_csv``
    imported from ``config`` in the imports cell.

    :param faiss_df: DataFrame with a ``quadrat_id`` column and a column of
        per-tile species-ID lists (see ``species_col``)
    :param output_path: Path of the CSV written when ``save_csv`` is True; may be
        a bare file name, in which case it is written to the current directory
    :param species_col: Column containing the predicted species IDs for each tile
    :param save_csv: If True, write the table to ``output_path`` with every field quoted
    :param top_k: Optional cap on the number of species kept per quadrat (most
        frequent first); ``None`` keeps all of them
    :return: DataFrame with columns ``quadrat_id`` and ``species_ids``, where
        ``species_ids`` is a string such as ``"[123, 456]"`` ordered by how many
        tile predictions voted for each species
    """
    grouped = faiss_df.groupby("quadrat_id")[species_col].apply(list)
    records = []
    for quadrat_id, species_id_lists in grouped.items():
        # flatten the per-tile lists, skipping unresolved (None) predictions
        flat_ids = [
            sid for sublist in species_id_lists for sid in sublist if sid is not None
        ]
        # most_common() yields each species once, most-voted first;
        # equal counts keep first-seen order
        ranked_ids = [species_id for species_id, _ in Counter(flat_ids).most_common()]
        if top_k is not None:
            ranked_ids = ranked_ids[:top_k]
        species_ids_str = f"[{', '.join(str(sid) for sid in ranked_ids)}]"
        records.append({"quadrat_id": quadrat_id, "species_ids": species_ids_str})

    # explicit columns so an empty input still yields the expected schema
    df_run = pd.DataFrame(records, columns=["quadrat_id", "species_ids"])
    if save_csv:
        output_dir = os.path.dirname(output_path)
        # os.makedirs("") raises, so only create a directory when the path has one
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        df_run.to_csv(output_path, sep=",", index=False, quoting=csv.QUOTE_ALL)
        print(f"Submission file saved to: {output_path}")
    return df_run

### Similarity Search Classification

In [None]:
# Number of nearest training images retrieved per test tile.
TOP_K = 5

classifier = FaissClassifier(train_df=train_df)

# One row per test tile; make_prediction L2-normalizes the queries itself.
embs_array = np.array(test["embeddings"].tolist(), dtype=np.float32)
query_embs = torch.tensor(embs_array, device=get_device())
preds, similarities = classifier.make_prediction(query_embs, k=TOP_K)

# Every tile gets exactly TOP_K neighbours, aligned row-for-row with `test`.
assert preds.shape == (len(test), TOP_K), preds.shape
assert similarities.shape == preds.shape

In [None]:
# Join neighbour predictions onto the test tiles and resolve species IDs.
faiss_classification_df = create_classification_dataframe(
    train_df=train_df,
    test_df=test,
    predictions=preds,
    similarities=similarities,
)


In [None]:
# Preview only; the full frame carries large embedding/logit columns per tile.
faiss_classification_df.head()

Unnamed: 0,quadrat_id,embeddings,logits,tile,predictions,similarities,pred_species_ids
0,CBN-PdlC-E3-20130723,"[1.7311939001083374, 1.790383219718933, 0.1066...","{'1355868': None, '1355869': None, '1355870': ...",0,"[e6d00593fc2dc5ebf2f6139213a8ff300875b449.jpg,...","[0.3274020552635193, 0.2963815927505493, 0.265...","[1390910, 1391649, 1361581, 1397431, 1721729]"
1,CBN-PdlC-E3-20130723,"[1.2909488677978516, 2.2347311973571777, -0.95...","{'1355868': None, '1355869': None, '1355870': ...",1,"[8e61a9f80be072ce42c5566be99e0fe509a3d0c9.jpg,...","[0.37576591968536377, 0.3643946051597595, 0.32...","[1391649, 1395807, 1359649, 1394598, 1395935]"
2,CBN-PdlC-E3-20130723,"[1.7716916799545288, 1.0189008712768555, -2.09...","{'1355868': None, '1355869': None, '1355870': ...",2,"[8e61a9f80be072ce42c5566be99e0fe509a3d0c9.jpg,...","[0.40316492319107056, 0.3029630780220032, 0.29...","[1391649, 1393906, 1743466, 1394598, 1395807]"
3,CBN-PdlC-E3-20130723,"[1.507507085800171, 2.4123475551605225, -0.713...","{'1355868': None, '1355869': None, '1355870': ...",3,"[e6d00593fc2dc5ebf2f6139213a8ff300875b449.jpg,...","[0.3446432054042816, 0.3070175051689148, 0.255...","[1390910, 1391649, 1390801, 1359750, 1361581]"
4,CBN-PdlC-E3-20130723,"[1.8004498481750488, 1.6963614225387573, -0.99...","{'1355868': None, '1355869': None, '1355870': ...",4,"[c32fd73f0f3156f519b2b2082bb79598a3bee665.jpg,...","[0.3922504782676697, 0.38106292486190796, 0.37...","[1394597, 1394598, 1394598, 1391649, 1394603]"
...,...,...,...,...,...,...,...
18940,GUARDEN-CBNMed-44-7-12-03-20240629,"[1.123538851737976, 1.2354389429092407, 1.1205...","{'1355868': None, '1355869': None, '1355870': ...",4,"[6813824c89b2b176b8201bd7841d2d4038f71906.jpg,...","[0.42272335290908813, 0.3474133014678955, 0.33...","[1392059, 1393933, 1391989, 1361367, 1361347]"
18941,GUARDEN-CBNMed-44-7-12-03-20240629,"[0.3686266839504242, 1.9914993047714233, -1.08...","{'1355868': None, '1355869': None, '1355870': ...",5,"[6813824c89b2b176b8201bd7841d2d4038f71906.jpg,...","[0.5283453464508057, 0.37342220544815063, 0.37...","[1392059, 1396806, 1396806, 1396806, 1393660]"
18942,GUARDEN-CBNMed-44-7-12-03-20240629,"[-0.7237932682037354, 1.4672207832336426, -0.0...","{'1355868': None, '1355869': None, '1355870': ...",6,"[6813824c89b2b176b8201bd7841d2d4038f71906.jpg,...","[0.45889079570770264, 0.33152082562446594, 0.3...","[1392059, 1393933, 1393660, 1396806, 1360904]"
18943,GUARDEN-CBNMed-44-7-12-03-20240629,"[0.5014994144439697, 1.088453769683838, 0.0358...","{'1355868': None, '1355869': None, '1355870': ...",7,"[6813824c89b2b176b8201bd7841d2d4038f71906.jpg,...","[0.3873826265335083, 0.34275487065315247, 0.32...","[1392059, 1393933, 1396806, 1396806, 1361347]"


In [None]:
# Aggregate tile predictions per quadrat, write the CSV, and display the table.
create_submission_csv(
    faiss_df=faiss_classification_df,
    output_path="/content/submission.csv",
    species_col="pred_species_ids",
    save_csv=True,
)

Submission file saved to: /content/submission.csv


Unnamed: 0,quadrat_id,species_ids
0,2024-CEV3-20240602,"[1395100, 1360203, 1393660, 1360208, 1394359, ..."
1,CBN-PdlC-A1-20130807,"[1391649, 1356422, 1395807, 1397535, 1396027, ..."
2,CBN-PdlC-A1-20130903,"[1356422, 1394542, 1395807, 1397535, 1647128, ..."
3,CBN-PdlC-A1-20140721,"[1392535, 1391649, 1394911, 1356013, 1356422, ..."
4,CBN-PdlC-A1-20140811,"[1356422, 1395807, 1356424, 1356013, 1549192, ..."
...,...,...
2100,RNNB-8-5-20240118,"[1361703, 1363852, 1356333, 1393855, 1356044, ..."
2101,RNNB-8-6-20240118,"[1361703, 1361024, 1356598, 1363149, 1356044, ..."
2102,RNNB-8-7-20240118,"[1359344, 1356448, 1356598, 1393855, 1361703, ..."
2103,RNNB-8-8-20240118,"[1396168, 1391480, 1359344, 1363852, 1359664, ..."
