<a href="https://colab.research.google.com/github/NickyTan8899/tjy/blob/main/%F0%9F%A6%99_%F0%9F%A7%AC_Baseline_with_WildFusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# IMPORTANT: SOME KAGGLE DATA SOURCES ARE PRIVATE
# RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES.
# kagglehub.login() asks for Kaggle credentials; the recorded output of this
# cell shows a login widget followed by a credential-validation message.
import kagglehub
kagglehub.login()


VBox(children=(HTML(value='<center> <img\nsrc=https://www.kaggle.com/static/images/site-logo.png\nalt=\'Kaggle…

Kaggle credentials set.
Kaggle credentials successfully validated.


In [2]:
# IMPORTANT: RUN THIS CELL IN ORDER TO IMPORT YOUR KAGGLE DATA SOURCES,
# THEN FEEL FREE TO DELETE THIS CELL.
# NOTE: THIS NOTEBOOK ENVIRONMENT DIFFERS FROM KAGGLE'S PYTHON
# ENVIRONMENT SO THERE MAY BE MISSING LIBRARIES USED BY YOUR
# NOTEBOOK.

# Competition data. NOTE(review): a later cell assigns
# `root = animal_clef_2025_path`, so deleting this cell breaks that cell.
animal_clef_2025_path = kagglehub.competition_download('animal-clef-2025')
# Outputs of two other Kaggle notebooks. NOTE(review): neither of these two
# path variables is read in the visible part of the notebook - confirm they
# are still needed.
hathawaytan_balanced_accuracy_path = kagglehub.notebook_output_download('hathawaytan/balanced-accuracy')
hathawaytan_baseline_with_wildfusion_path = kagglehub.notebook_output_download('hathawaytan/baseline-with-wildfusion')

print('Data source import complete.')


Data source import complete.


# Baseline with WildFusion 🦙 🧬

This notebook presents an improved baseline for individual identification using the **[MegaDescriptor](https://arxiv.org/pdf/2311.09118)** with **[ALIKED](https://arxiv.org/pdf/2304.03608)** methods, combined using [WildFusion](https://arxiv.org/pdf/2408.12934).

## What is WildFusion?  
WildFusion is a **feature fusion method** designed to improve the accuracy and robustness of models in individual recognition tasks. Traditional identification models often rely on a single type of feature representation, which may struggle with variations in lighting, angles, and occlusions. WildFusion overcomes these challenges by combining multiple feature extraction techniques, resulting in a more comprehensive and adaptable approach.  

## Dependencies installation
For the competition, we provide two Python packages: [wildlife-datasets](https://github.com/WildlifeDatasets/wildlife-datasets) for loading and preprocessing the available datasets, and [wildlife-tools](https://github.com/WildlifeDatasets/wildlife-tools) with tools and methods for animal re-identification.

In [3]:
!pip install git+https://github.com/WildlifeDatasets/wildlife-datasets@develop
!pip install git+https://github.com/WildlifeDatasets/wildlife-tools

Collecting git+https://github.com/WildlifeDatasets/wildlife-datasets@develop
  Cloning https://github.com/WildlifeDatasets/wildlife-datasets (to revision develop) to /tmp/pip-req-build-obgfdez6
  Running command git clone --filter=blob:none --quiet https://github.com/WildlifeDatasets/wildlife-datasets /tmp/pip-req-build-obgfdez6
  Running command git checkout -b develop --track origin/develop
  Switched to a new branch 'develop'
  Branch 'develop' set up to track remote branch 'develop' from 'origin'.
  Resolved https://github.com/WildlifeDatasets/wildlife-datasets to commit 753d9bf64861c3e17011136b3436bf58bf02317f
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting datasets (from wildlife-datasets==1.0.6)
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets->wildlife-datasets==1.0.6)
  Downloading dill-

In [3]:
import numpy as np
from typing import List, Union

def baks_compute(
        y_true: Union[List, np.ndarray],
        y_pred: Union[List, np.ndarray],
        identity_test_only: Union[List, np.ndarray]
    ) -> float:
    """Computes BAKS (balanced accuracy on known samples).

    Focuses only on samples with known identities (not in identity_test_only).
    The score is the unweighted mean, over every known identity that occurs in
    ``y_true``, of the fraction of that identity's samples predicted correctly.

    Args:
        y_true: True labels
        y_pred: Predicted labels (same length as ``y_true``)
        identity_test_only: Labels of unknown identities (only in test set)

    Returns:
        Balanced accuracy score for known samples, or 0.0 when no sample has a
        known identity.
    """
    # Object dtype keeps every label as its original Python object, so labels
    # of different types (e.g. int and str) are not coerced to a common type.
    y_true = np.array(y_true, dtype=object)
    y_pred = np.array(y_pred, dtype=object)
    identity_test_only = np.array(identity_test_only, dtype=object)

    # Filter out unknown samples
    mask = ~np.isin(y_true, identity_test_only)
    y_true_known = y_true[mask]
    y_pred_known = y_pred[mask]

    if len(y_true_known) == 0:
        return 0.0

    try:
        unique_classes = np.unique(y_true_known)
    except TypeError:
        # np.unique sorts its input, and labels of mixed types cannot be
        # ordered. Fall back to an order-preserving de-duplication that only
        # needs hashing and equality.
        unique_classes = list(dict.fromkeys(y_true_known.tolist()))

    # Every class comes from y_true_known, so each per-class selection is
    # guaranteed to be non-empty.
    class_accuracies = [
        np.mean(y_pred_known[y_true_known == cls] == cls)
        for cls in unique_classes
    ]

    # Return the balanced accuracy (mean of per-class accuracies)
    return float(np.mean(class_accuracies))

def baus_compute(
        y_true: Union[List, np.ndarray],
        y_pred: Union[List, np.ndarray],
        identity_test_only: Union[List, np.ndarray],
        new_class: Union[int, str]
    ) -> float:
    """Computes BAUS (balanced accuracy on unknown samples).

    Focuses only on samples with unknown identities (in identity_test_only).
    The score is the unweighted mean, over every unknown identity that occurs
    in ``y_true``, of the fraction of that identity's samples predicted as
    ``new_class``.

    Args:
        y_true: True labels
        y_pred: Predicted labels (same length as ``y_true``)
        identity_test_only: Labels of unknown identities (only in test set)
        new_class: Label used for identifying unknown samples

    Returns:
        Balanced accuracy score for unknown samples, or 0.0 when no sample has
        an unknown identity.
    """
    # Object dtype keeps every label as its original Python object, so labels
    # of different types (e.g. int and str) are not coerced to a common type.
    y_true = np.array(y_true, dtype=object)
    y_pred = np.array(y_pred, dtype=object)
    identity_test_only = np.array(identity_test_only, dtype=object)

    # Filter to include only unknown samples
    mask = np.isin(y_true, identity_test_only)
    y_true_unknown = y_true[mask]
    y_pred_unknown = y_pred[mask]

    if len(y_true_unknown) == 0:
        return 0.0

    try:
        unique_unknown_classes = np.unique(y_true_unknown)
    except TypeError:
        # np.unique sorts its input, and labels of mixed types cannot be
        # ordered. Fall back to an order-preserving de-duplication that only
        # needs hashing and equality.
        unique_unknown_classes = list(dict.fromkeys(y_true_unknown.tolist()))

    # For unknown samples the correct prediction is always new_class. Every
    # class comes from y_true_unknown, so each selection is non-empty.
    class_accuracies = [
        np.mean(y_pred_unknown[y_true_unknown == cls] == new_class)
        for cls in unique_unknown_classes
    ]

    # Return the balanced accuracy (mean of per-class accuracies)
    return float(np.mean(class_accuracies))


def compute_geometric_mean(baks, baus):
    """Return the geometric mean of the two scores, i.e. sqrt(baks * baus)."""
    product = baks * baus
    return np.sqrt(product)

## Dependencies import
We load all the required packages and then define the function `create_sample_submission`, which converts the provided predictions into a submission file for the competition.

In [66]:
import os
import numpy as np
import pandas as pd
import timm
import torchvision.transforms as T
from wildlife_datasets.datasets import AnimalCLEF2025
from wildlife_tools.features import DeepFeatures
from wildlife_tools.similarity import CosineSimilarity
from wildlife_tools.similarity.wildfusion import SimilarityPipeline, WildFusion
from wildlife_tools.similarity.pairwise.lightglue import MatchLightGlue
from wildlife_tools.similarity.pairwise.loftr import MatchLOFTR
from wildlife_tools.features.local import AlikedExtractor,SuperPointExtractor,SiftExtractor,DiskExtractor
from wildlife_tools.similarity.calibration import IsotonicCalibration,LogisticCalibration
import sys
from wildlife_tools.similarity.pairwise.collectors import CollectCounts, CollectCountsRansac, CollectAll
# NOTE(review): Kaggle-specific absolute path, presumably only needed for the
# commented-out `metric` import that follows - confirm it is still required.
sys.path.append('/kaggle/input/balanced-accuracy')  # add the directory to the module search path
# from metric import score,BAKS,BAUS

def create_sample_submission(dataset_query, predictions, file_name='submission.csv'):
    """Write a two-column submission CSV (`image_id`, `identity`).

    Args:
        dataset_query: Object whose `metadata['image_id']` supplies the ids.
        predictions: One predicted identity per entry of `image_id`.
        file_name: Destination path of the CSV file.
    """
    columns = {'image_id': dataset_query.metadata['image_id']}
    columns['identity'] = predictions
    submission = pd.DataFrame(columns)
    # The DataFrame index is not part of the submission format.
    submission.to_csv(file_name, index=False)

## Inference WildFusion

Instead of training a classifier, we use off-the-shelf pretrained models - [MegaDescriptor](https://huggingface.co/BVRA/MegaDescriptor-L-384) and [ALIKED](https://arxiv.org/pdf/2304.03608), a keypoint and descriptor extraction network. Both MegaDescriptor and ALIKED are used to extract features from all images.

**Note:** _It is highly recommended to use the GPU acceleration._

We need to specify the `root` where the data are stored, and then two image transformations.
1. The first transform only resizes the images and is used for visualization.
2. The second transform also converts it to torch tensor and is used for operations on neural networks.

In [43]:

# Root directory of the competition data, as returned by the kagglehub
# download cell. The Kaggle-native location '/kaggle/input/animal-clef-2025'
# used to be assigned here first and was immediately overwritten, so that dead
# assignment has been dropped.
root = animal_clef_2025_path

# 1. Display transform: only resizes the image.
transform_display = T.Compose([
    T.Resize([224, 224]),
])
# 2. Network transform: the same resize, then tensor conversion and
#    per-channel normalisation.
transform = T.Compose([
    *transform_display.transforms,
    T.ToTensor(),
    T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

# NOTE(review): the two transforms below are identical to each other and are
# not referenced by the visible cells (the pipelines build their own 256x256
# transforms) - confirm whether they are still needed.
transforms_aliked = T.Compose([
    T.Resize([512, 512]),
    T.ToTensor()
])
transforms_sift = T.Compose([
    T.Resize([512, 512]),
    T.ToTensor()
])

In [67]:
# Loading the dataset
dataset = AnimalCLEF2025(root, load_label=True)

# Split the full dataset into the reference (database) and query parts.
split_column = dataset.metadata['split']
dataset_database = dataset.get_subset(split_column == 'database')
dataset_query = dataset.get_subset(split_column == 'query')

# Calibration set: built from the first 100 rows of the database metadata.
calibration_metadata = dataset_database.metadata[:100]
dataset_calibration = AnimalCLEF2025(root, df=calibration_metadata, load_label=True)

# NOTE(review): `collector` is not referenced by any visible cell.
collector = CollectCountsRansac(ransacReprojThreshold=1.0, maxIters=100)

n_query = len(dataset_query)


In [7]:
# Display the number of samples in the query split (n_query = len(dataset_query)).
n_query

2135

In [70]:
# Loading the models
import torch  # needed here only to detect whether a CUDA device is available

name = "hf-hub:BVRA/MegaDescriptor-B-224"
model = timm.create_model(name, num_classes=0, pretrained=True)
# A GPU is strongly recommended, but fall back to the CPU instead of crashing
# when the runtime has no CUDA device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def make_local_pipeline(matcher, extractor, grayscale=False):
    """Build one calibrated local-matching pipeline.

    All local pipelines share the same preprocessing (resize to 256x256, then
    conversion to a tensor) and each gets its own isotonic calibration.

    Args:
        matcher: Pairwise matcher instance.
        extractor: Local feature extractor instance (None is passed for LoFTR).
        grayscale: If True, insert `T.Grayscale()` between the resize and the
            tensor conversion.

    Returns:
        The configured `SimilarityPipeline`.
    """
    steps = [T.Resize([256, 256])]
    if grayscale:
        steps.append(T.Grayscale())
    steps.append(T.ToTensor())
    return SimilarityPipeline(
        matcher=matcher,
        extractor=extractor,
        transform=T.Compose(steps),
        calibration=IsotonicCalibration(),
    )


# Same pipelines, in the same order, as the previous hand-written list.
pipelines = [
    make_local_pipeline(MatchLightGlue(features='superpoint'), SuperPointExtractor()),
    make_local_pipeline(MatchLightGlue(features='aliked'), AlikedExtractor()),
    make_local_pipeline(MatchLightGlue(features='disk'), DiskExtractor()),
    make_local_pipeline(MatchLightGlue(features='sift'), SiftExtractor()),
    make_local_pipeline(MatchLOFTR(pretrained='indoor'), None, grayscale=True),

    # Disabled alternative: a deep-feature pipeline based on a larger model.
    # SimilarityPipeline(
    #     matcher = CosineSimilarity(),
    #     extractor = DeepFeatures(
    #         model = timm.create_model('hf-hub:BVRA/wildlife-mega-L-384', num_classes=0, pretrained=True)
    #     ),
    #     transform = T.Compose([
    #         T.Resize(size=(384, 384)),
    #         T.ToTensor(),
    #         T.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    #     ]),
    #     calibration = IsotonicCalibration()
    # ),
]

# Global-descriptor pipeline; it is handed to WildFusion as the priority pipeline.
matcher_mega = SimilarityPipeline(
    matcher = CosineSimilarity(),
    extractor = DeepFeatures(model=model, device=device, batch_size=16),
    transform = transform,
    calibration =IsotonicCalibration()
)

Downloading: "http://cmp.felk.cvut.cz/~mishkdmy/models/loftr_indoor.ckpt" to /root/.cache/torch/hub/checkpoints/loftr_indoor.ckpt
100%|██████████| 44.2M/44.2M [00:02<00:00, 20.4MB/s]


In [71]:
# Calibrating the WildFusion
import gc
import torch

# Release unreachable Python objects and PyTorch's cached CUDA memory before
# the calibration step.
gc.collect()
torch.cuda.empty_cache()

wildfusion = WildFusion(
    calibrated_pipelines=pipelines,
    priority_pipeline=matcher_mega,
)
# The calibration subset is passed as both arguments.
wildfusion.fit_calibration(dataset_calibration, dataset_calibration)

100%|█████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.24it/s]
100%|█████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.27it/s]
100%|███████████████████████████████████████████████████████████████| 79/79 [00:16<00:00,  4.75it/s]
100%|█████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.04it/s]
100%|█████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.60it/s]
100%|███████████████████████████████████████████████████████████████| 79/79 [00:15<00:00,  5.06it/s]
100%|█████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.10it/s]
100%|█████████████████████████████████████████████████████████████| 100/100 [00:02<00:00, 34.75it/s]
100%|███████████████████████████████████████████████████████████████| 79/79 [00:15<00:00,  5.14it/s]
100%|█████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 2

In [72]:
# Compute WildFusion similarity
# NOTE(review): B=25 is presumably the per-query shortlist size chosen by the
# priority pipeline before the calibrated matchers run - confirm against the
# WildFusion documentation. The next cell indexes the result as
# [query row, database column].
similarity = wildfusion(dataset_query, dataset_database, B=25)

100%|█████████████████████████████████████████████████████████████| 134/134 [00:51<00:00,  2.62it/s]
100%|█████████████████████████████████████████████████████████████| 818/818 [02:10<00:00,  6.27it/s]
100%|███████████████████████████████████████████████████████████| 2135/2135 [00:53<00:00, 40.21it/s]
100%|█████████████████████████████████████████████████████████| 13074/13074 [02:28<00:00, 87.85it/s]
100%|█████████████████████████████████████████████████████████████| 417/417 [01:28<00:00,  4.69it/s]
100%|███████████████████████████████████████████████████████████| 2135/2135 [00:54<00:00, 39.40it/s]
100%|█████████████████████████████████████████████████████████| 13074/13074 [03:00<00:00, 72.50it/s]
100%|█████████████████████████████████████████████████████████████| 417/417 [01:23<00:00,  4.99it/s]
100%|███████████████████████████████████████████████████████████| 2135/2135 [00:53<00:00, 39.95it/s]
100%|█████████████████████████████████████████████████████████| 13074/13074 [02:34<00:00, 8

In [79]:
# Column index of the best-scoring database entry for every query row: the
# last element of an ascending argsort along axis 1.
pred_idx = np.argsort(similarity, axis=1)[:, -1]

# Similarity score of that best match, one value per query.
query_rows = range(n_query)
pred_scores = similarity[query_rows, pred_idx]

# Show the scores as the cell output.
pred_scores

array([0.05908, 0.07043, 0.6357 , ..., 0.4414 , 0.07886, 0.813  ],
      dtype=float16)

np.float16(0.0581)

In [80]:
new_individual = 'new_individual'
# Matches scoring below this value are treated as previously unseen individuals.
threshold = 0.6
labels = dataset_database.labels_string
# Force object dtype before indexing: with a fixed-width string array the
# assignment below could silently truncate 'new_individual'. Fancy indexing
# yields a fresh array, so `labels` itself is never modified.
predictions = np.asarray(labels, dtype=object)[pred_idx]
predictions[pred_scores < threshold] = new_individual
# Write to the current working directory rather than the Colab-only
# '/content' folder, so the cell also runs where '/content' does not exist.
create_sample_submission(dataset_query, predictions, file_name='submission.csv')

In [None]:
# Metadata rows of the query split. This reads `dataset.metadata`, the
# attribute the dataset-loading cell already uses, and builds `df` in a single
# step instead of overwriting it with a filtered copy of itself.
query_mask = dataset.metadata['split'] == 'query'
df = dataset.metadata[query_mask]

In [None]:
all_ids = df['identity'].unique()

# Build the membership sets once; the previous loop rebuilt both identity
# lists and scanned them linearly for every candidate id.
query_ids = set(dataset_query.metadata['identity'].tolist())
database_ids = set(dataset_database.metadata['identity'].tolist())

# Identities that occur in the query split but never in the database split,
# in the order they appear in `all_ids`.
unseen_ids = [i for i in all_ids if i in query_ids and i not in database_ids]
all_ids

In [None]:
# val_true_labels = dataset_query.labels_string
# baks_score = baks_compute(val_true_labels, predictions, unseen_ids)
# baus_score = baus_compute(val_true_labels, predictions, unseen_ids, "new_individual")
# geo_mean = compute_geometric_mean(baks_score, baus_score)

# print(f"Balanced Accuracy Known Samples (BAKS): {baks_score:.4f}")
# print(f"Balanced Accuracy Unknown Samples (BAUS): {baus_score:.4f}")
# print(f"Geometric Mean (BAKS & BAUS): {geo_mean:.4f}")