In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.insert(0, '..')

import torchvision.transforms as T
import numpy as np
import pandas as pd

from wildlife_datasets import datasets, loader
from wildlife_tools.data import WildlifeDataset
from timm import create_model

In [None]:
import copy
from collections.abc import Iterable

class WD(WildlifeDataset):
    """WildlifeDataset subclass with helpers to plot images and retrieval results in a grid."""

    def plot_grid(self, **kwargs):
        """Draw a grid of this dataset's images via ``DatasetFactory.plot_grid``.

        Images are read from a deep copy of the dataset with ``transform`` set to
        ``None`` and ``load_label`` set to ``False``, so this instance is left
        untouched. All keyword arguments are forwarded; ``rotate`` defaults to
        ``False``.
        """
        raw_view = copy.deepcopy(self)
        raw_view.transform = None
        raw_view.load_label = False

        def load_raw(k):
            return raw_view.__getitem__(k)

        kwargs.setdefault('rotate', False)
        factory = datasets.DatasetFactory(self.root, self.metadata)
        return factory.plot_grid(loader=load_raw, **kwargs)

    def plot_predictions(self, y_true, y_pred, **kwargs):
        """Plot each query image next to its predicted matches, one grid row per query.

        Args:
            y_true: a single query index, or a sequence of query indices
                (positions into ``self.metadata``).
            y_pred: for a single query, a 1-D sequence of predicted indices;
                otherwise one sequence of predicted indices per query.
            **kwargs: forwarded to ``plot_grid``; ``n_rows`` defaults to
                ``min(number of queries, 5)``.

        With at most one query, the column headers are the identities of the
        shown images; with several queries they are "Query" / "Match n".
        """
        if not isinstance(y_true, Iterable) and np.array(y_pred).ndim == 1:
            y_true, y_pred = [y_true], [y_pred]

        if len(y_true) <= 1:
            ids = self.metadata['identity'].to_numpy()
            header_cols = [ids[y_true[0]], ""] + [ids[p] for p in y_pred[0]]
        else:
            n_matches = len(y_pred[0])
            header_cols = ["Query", ""] + [f"Match {m + 1}" for m in range(n_matches)]

        # One row per query: query index, -1 for the spacer column (header ""),
        # then the predictions. NOTE(review): presumably plot_grid renders -1 as
        # an empty cell — confirm against DatasetFactory.plot_grid.
        rows = [[query, -1, *preds] for query, preds in zip(y_true, y_pred)]
        n_rows = kwargs.pop('n_rows', min(len(y_true), 5))
        return self.plot_grid(idx=rows, n_rows=n_rows, n_cols=len(header_cols), header_cols=header_cols, **kwargs)

In [None]:
from wildlife_tools.similarity import CosineSimilarity
import matplotlib.pyplot as plt

# TODO: come up with a more descriptive name
def compute_predictions_disjoint(features, k=4, batch_size=1000, matcher=None):
    """For every row of ``features``, find the ``k`` most similar *other* rows.

    Every row is used as a query against all rows (query and database sets
    coincide), and each query is excluded from its own ranking. Similarities
    are computed for ``batch_size`` queries at a time, which bounds the size of
    the intermediate similarity matrix.

    Args:
        features: feature vectors, one row per sample, indexable by an integer
            index array (e.g. a 2-D numpy array).
        k: number of neighbours returned per query. It must not exceed the
            number of rows; if it equals it, the last column is the query itself.
        batch_size: number of queries scored per chunk.
        matcher: callable ``matcher(query=..., database=...)`` returning a dict
            whose ``'cosine'`` entry is a writable float similarity matrix of
            shape ``(len(query), len(database))``. Defaults to ``CosineSimilarity()``.

    Returns:
        ``(idx_true, idx_pred)``: ``idx_true`` is ``arange(n)``; ``idx_pred`` is
        an ``int32`` array of shape ``(n, k)`` with neighbour indices ordered by
        decreasing similarity.
    """
    n_query = len(features)
    idx_true = np.arange(n_query)
    idx_pred = np.zeros((n_query, k), dtype=np.int32)
    if n_query == 0:
        # np.array_split rejects 0 sections, so handle empty input up front.
        return idx_true, idx_pred
    if matcher is None:
        matcher = CosineSimilarity()

    n_chunks = int(np.ceil(n_query / batch_size))
    for chunk in np.array_split(np.arange(n_query), n_chunks):
        similarity = matcher(query=features[chunk], database=features)['cosine']
        # Exclude each query from its own ranking. -inf is used instead of -1,
        # because -1 is a legitimate cosine similarity and would tie with self.
        similarity[np.arange(len(chunk)), chunk] = -np.inf
        idx_pred[chunk, :] = (-similarity).argsort(axis=-1)[:, :k]
    return idx_true, idx_pred

def find_wrong_labels(dataset, idx_true, idx_pred, k=4, min_count=3):
    """Flag samples whose nearest neighbours agree on a different identity.

    For each query ``i`` with neighbour indices ``js``, the identities of the
    neighbours are counted. The query is flagged as a potentially wrong label
    if some identity other than its own occurs at least ``min_count`` times.

    Args:
        dataset: object whose ``metadata`` DataFrame has an ``identity`` column;
            indices are positional (``iloc``).
        idx_true: iterable of query indices.
        idx_pred: iterable of neighbour-index sequences aligned with ``idx_true``.
        k: number of leading neighbours kept in the output for each flagged query.
        min_count: minimum number of neighbours that must share a different
            identity for the query to be flagged.

    Returns:
        ``(idx_true_wrong, idx_pred_wrong)``: the flagged query indices and
        their truncated neighbour indices. Each query appears at most once.
    """
    identity = dataset.metadata['identity']
    idx_true_wrong = []
    idx_pred_wrong = []
    for i, js in zip(idx_true, idx_pred):
        y_true = identity.iloc[i]
        pred_counts = identity.iloc[js].value_counts()
        frequent = pred_counts[pred_counts >= min_count].index
        # Append once per query even if several other identities pass the
        # threshold (the original loop appended once per such identity).
        if any(label != y_true for label in frequent):
            idx_true_wrong.append(i)
            idx_pred_wrong.append(js[:k])
    return idx_true_wrong, idx_pred_wrong

def verify_wrong_labels(load_image, keypoint_extractor, keypoint_matcher, image_matcher, dataset, idx_true, idx_pred):
    """Confirm suspected mislabels by keypoint matching and plot confirmed cases.

    For each query ``i`` and each predicted neighbour ``j`` whose identity
    differs from the query's, keypoints of both images are extracted and
    matched. If ``image_matcher`` reports a match, the query row is plotted with
    ``dataset.plot_predictions`` and the remaining neighbours of that query are
    skipped.

    Args:
        load_image: callable mapping a positional index to an image accepted by
            ``keypoint_extractor.generate_keypoints``.
        keypoint_extractor: object with ``generate_keypoints(img)`` returning
            ``(keypoints, descriptors)``.
        keypoint_matcher: object with ``compute_match(kpt1, desc1, kpt2, desc2)``.
        image_matcher: object whose ``compute_matches_info(kpt1, kpt2, matches)``
            returns a sequence whose first element is truthy on a match.
        dataset: object with ``metadata['identity']`` and ``plot_predictions``.
        idx_true, idx_pred: aligned query indices and neighbour-index sequences.

    Returns:
        List of confirmed ``(i, j)`` pairs, at most one per query.
    """
    identity = dataset.metadata['identity']
    confirmed = []
    for i, js in zip(idx_true, idx_pred):
        # The query's keypoints do not depend on j: compute them lazily, once.
        query_keypoints = None
        for j in js:
            if identity.iloc[i] == identity.iloc[j]:
                continue
            if query_keypoints is None:
                query_keypoints = keypoint_extractor.generate_keypoints(load_image(i))
            kpt1, desc1 = query_keypoints
            kpt2, desc2 = keypoint_extractor.generate_keypoints(load_image(j))

            matches = keypoint_matcher.compute_match(kpt1, desc1, kpt2, desc2)
            matched = image_matcher.compute_matches_info(kpt1, kpt2, matches)[0]
            if matched:
                confirmed.append((i, j))
                fig = dataset.plot_predictions(i, js)
                # pyplot.show() takes no figure argument; show, then free this figure.
                plt.show()
                plt.close(fig)
                break
    return confirmed

In [None]:
# Pretrained backbone loaded from the Hugging Face hub through timm.
# NOTE(review): `model` is not referenced by any visible cell; features are read
# from ../data/_features/*.npy below — confirm whether this load is still needed.
model = create_model("hf-hub:BVRA/wildlife-mega-L-384", pretrained=True)

# Preprocessing: resize to 384x384, convert to a tensor, then normalise each
# channel with the standard ImageNet mean/std.
transform = T.Compose(
    [
        T.Resize(size=(384, 384)),
        T.ToTensor(),
        T.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

In [None]:
sys.path.append('../../Turtles_Identification/src')

from image_loader import Loader
from keypoint_extractor import SIFT
from keypoint_matcher import L1Matcher
from image_matcher import AlignMatcher

# Keypoint pipeline from the external Turtles_Identification project; passed to
# verify_wrong_labels to double-check suspected label errors.
keypoint_extractor, keypoint_matcher = SIFT(), L1Matcher()
image_matcher = AlignMatcher(
    n_keypoints=10,
    method=1,  # NOTE(review): meaning of method=1 is defined in image_matcher.AlignMatcher
)

In [None]:
import os 
from collections.abc import Iterable

# For every dataset: load precomputed features, find samples whose nearest
# neighbours agree on a different identity, and verify them by keypoint matching.
for dataset_name in datasets.names_all:
    print(dataset_name)
    metadata = loader.load_dataset(dataset_name, '../data', '../data/_dataframes/')
    if metadata.__class__.__name__ in ['Drosophila']:
        print('Skipping dataset')
        continue

    features_path = '../data/_features/features_' + dataset_name.__name__ + '.npy'
    if not os.path.exists(features_path):
        # Don't abort the whole sweep because one dataset has no features yet.
        print('Skipping dataset: no features at %s' % features_path)
        continue

    print('Loading dataset')
    has_bbox = 'bbox' in metadata.df.columns
    if has_bbox:
        # Rows with no (iterable) bbox get [0, 0, img.size[0], img.size[1]] so
        # that img_load="bbox" works for every row.
        for i_row, df_row in metadata.df.iterrows():
            if not isinstance(df_row['bbox'], Iterable):
                img = datasets.utils.get_image(os.path.join(metadata.root, df_row['path']))
                metadata.df.at[i_row, 'bbox'] = [0, 0, img.size[0], img.size[1]]
    img_load = "bbox" if has_bbox else "crop_black"

    dataset = WD(metadata.df, metadata.root, img_load=img_load, transform=transform)
    features = np.load(features_path)
    # Separate view resized to 256x256, used only to load images for keypoint matching.
    dataset_loader = WD(metadata.df, metadata.root, img_load=img_load, transform=T.Resize(size=(256, 256)))

    def load_image(k, _dataset=dataset_loader):
        return np.array(_dataset[k][0])

    print('Computing predictions')
    idx_true, idx_pred = compute_predictions_disjoint(features)
    print('Finding potentially wrong labels')
    idx_true, idx_pred = find_wrong_labels(dataset, idx_true, idx_pred)
    print('Found %d potentially wrong labels' % len(idx_true))
    print('Verifying predictions')
    verify_wrong_labels(load_image, keypoint_extractor, keypoint_matcher, image_matcher, dataset, idx_true, idx_pred)