In [27]:
from data.dataset import WildlifeDataset
from torchvision import transforms as T
import pandas as pd

# Folder holding the dataset; `annotations.csv` is read from it below and the
# same path is passed as `root` to every WildlifeDataset in this notebook.
# NOTE(review): absolute, user-specific path - the notebook will not run elsewhere as-is.
root = '/home/cermavo3/projects/wildlife-experiments/data_256/StripeSpotter'
# Alternative dataset root, kept commented out for quick switching.
#root = '/home/cermavo3/projects/wildlife-experiments/data_256/Giraffes'
# Annotation table; the first CSV column is used as the DataFrame index.
metadata = pd.read_csv(f'{root}/annotations.csv', index_col=0)



import numpy as np
def create_distance_matrix(sim, idx_train, idx_test):
    '''
    Select the (test x train) block of a mirrored pairwise score matrix.

    The matching algorithm fills only the upper triangle of its output, so the
    matrix is mirrored across the diagonal before the block is cut out.
    The input itself is not modified (a float32 copy is used).

    Input: upper triangular matrix with shape (n_total, n_total) with zeros on diagonal.
    Output: matrix of shape (n_test, n_train) holding the mirrored input values.
        Diagonal entries, if they are selected, are np.inf.

    Raises ValueError when the input is not upper triangular or when its
    diagonal is not all zeros.
    '''
    scores = sim.copy().astype(np.float32)

    is_upper_triangular = np.allclose(scores, np.triu(scores))
    if not is_upper_triangular:
        raise ValueError('Input matrix needs to be upper triangular.')

    if np.any(np.diag(scores) != 0):
        raise ValueError('Input matrix needs to have zeros in diagonal.')

    # Self-pairs are marked with inf; mirroring keeps them at inf (inf + inf).
    np.fill_diagonal(scores, np.inf)
    mirrored = scores + scores.T

    # Columns are restricted to train first, then rows to test.
    train_columns = mirrored[:, idx_train]
    return train_columns[idx_test, :]



from features.sift import SIFTFeatures
from similarity.descriptors import MatchDescriptors
from data.split import SplitWildlife
from wildlife_datasets import splits

# Grayscale PIL images
transform = T.Compose([
    T.Resize(size=256),
    T.Grayscale(),
    T.CenterCrop(size=(224, 224)),
])

dataset_database = WildlifeDataset(
    metadata=metadata,
    root=root,
    transform=transform,
    split=SplitWildlife(splits.ClosedSetSplit(0.8, identity_skip='unknown', seed=666), split='train'),
)

dataset_query = WildlifeDataset(
    metadata=metadata,
    root=root,
    transform=transform,
    split=SplitWildlife(splits.ClosedSetSplit(0.8, identity_skip='unknown', seed=666), split='test')
)

dataset_all = WildlifeDataset(
    metadata=metadata,
    root=root,
    transform=transform,
)

In [64]:
from wildlife_datasets import splits


# Location of a previously stored similarity file from an earlier run.
# NOTE(review): absolute, user-specific path.
root2 = '/home/cermavo3/projects/wildlife-experiments/experiments/matchers_descriptor/runs'
path2 = f"{root2}/StripeSpotter_sift_1-1_Jul19-09-41-55-1540/split-0/similarity.pickle"


# Same splitter arguments and seed as used for dataset_database / dataset_query.
splitter = splits.ClosedSetSplit(0.8, identity_skip='unknown', seed=666)
# First (train, test) index pair returned by the splitter.
idx_train, idx_test = splitter.split(metadata)[0]

import os
import pickle

# NOTE(review): pickle.load executes arbitrary code from the file - only load trusted files.
with open(os.path.join(path2), 'rb') as handle:
    similarity_old = pickle.load(handle)
# The stored object is indexed by threshold; take the matrix for threshold 0.6.
sim_old = similarity_old[0.6]
# Mirror the stored upper-triangular matrix and cut out the (test x train) block.
sim2 = create_distance_matrix(sim_old, idx_train, idx_test)


# All-pairs matching over the whole dataset (no database argument is given).
# NOTE(review): SIFTMatcher is defined in a cell further down this notebook, so
# this cell only works after that cell has been executed.
matcher2 = SIFTMatcher(device='cuda', thresholds=[0.2, 0.4, 0.6])
matcher2.train(dataset_all) # GOOD: This matches old stuff.

Mode: Matching all pairs in query


100%|██████████| 824/824 [00:07<00:00, 105.88it/s]


Total pairs     : 339076
Pairs in chunk  : 339076


100%|██████████| 339076/339076 [00:49<00:00, 6912.35it/s]


In [65]:
# Show the (test x train) block cut from the stored similarity matrix.
sim2

array([[4., 7., 2., ..., 6., 4., 1.],
       [8., 7., 2., ..., 3., 3., 3.],
       [1., 4., 0., ..., 2., 4., 1.],
       ...,
       [1., 0., 0., ..., 3., 1., 3.],
       [2., 2., 0., ..., 1., 2., 1.],
       [1., 1., 3., ..., 1., 5., 5.]], dtype=float32)

In [66]:
# Query-vs-database matching: both datasets are passed, so every query image
# is matched against every database image.
# NOTE(review): train() as defined in this file has no return statement, so the
# array printed under this cell was presumably produced by a trailing
# expression that is no longer part of the cell - TODO confirm.
matcher = SIFTMatcher(device='cuda', thresholds=[0.2, 0.4, 0.6])
matcher.train(dataset_query=dataset_query, dataset_database=dataset_database)

Mode: Matching query with database


100%|██████████| 164/164 [00:01<00:00, 105.53it/s]
100%|██████████| 656/656 [00:06<00:00, 106.25it/s]


Total pairs     : 107584
Pairs in chunk  : 107584


100%|██████████| 107584/107584 [00:15<00:00, 7007.40it/s]


array([[1., 2., 0., ..., 0., 2., 0.],
       [8., 7., 2., ..., 3., 0., 3.],
       [1., 4., 0., ..., 2., 4., 1.],
       ...,
       [1., 0., 0., ..., 2., 3., 3.],
       [2., 2., 0., ..., 1., 2., 1.],
       [1., 1., 1., ..., 1., 5., 5.]], dtype=float16)

In [75]:
# Element-wise equality check of the matcher result against `similarity` at threshold 0.6.
# presumably `similarity` is the dict returned by MatchDescriptors.calculate in
# another cell - depends on cell execution order, verify.
(matcher.similarity[0.6] == similarity[0.6]).all()

True

In [3]:
import cv2
import numpy as np
from tqdm import tqdm
import itertools
import faiss
from similarity.base import Similarity


def get_faiss_index(d, device='cpu'):
    '''
    Create an empty flat L2 faiss index for vectors of dimension d.

    device='cpu' returns faiss.IndexFlatL2, device='cuda' returns
    faiss.GpuIndexFlatL2 configured for GPU 0.
    Any other device value raises ValueError.
    '''
    if device == 'cpu':
        return faiss.IndexFlatL2(d)

    if device != 'cuda':
        raise ValueError(f'Invalid device: {device}')

    # GPU index: the device id is fixed to 0.
    gpu_resource = faiss.StandardGpuResources()
    gpu_config = faiss.GpuIndexFlatConfig()
    gpu_config.device = 0
    return faiss.GpuIndexFlatL2(gpu_resource, d, gpu_config)


class MatchDescriptors(Similarity):
    '''
    Similarity given by the number of ratio-test matches between two sets of
    local descriptors.

    For every (query, database) pair the query descriptors are put into a
    faiss index, the two nearest query descriptors are looked up for each
    database descriptor, and a database descriptor counts as a match when
    (nearest distance / second nearest distance) is below the threshold.

    Args:
        descriptor_dim: dimension of a single descriptor vector.
        thresholds: ratio thresholds; one similarity matrix is produced per threshold.
        device: 'cpu' or 'cuda', forwarded to get_faiss_index.
    '''
    def __init__(
        self,
        descriptor_dim: int = 128,
        thresholds: tuple[float, ...] = (0.5, ),
        device: str = 'cpu',
    ):

        self.descriptor_dim = descriptor_dim
        self.thresholds = thresholds
        self.device = device


    def calculate(self, query, database):
        '''
        Count ratio-test matches for every query / database pair.

        Args:
            query: sequence of descriptor arrays, or None for an item without descriptors.
            database: sequence of the same kind.

        Returns:
            dict mapping each threshold to a float16 array of shape
            (len(query), len(database)) with the match counts. Pairs where
            either side is None get 0.
        '''
        iterator = itertools.product(enumerate(query), enumerate(database))
        iterator_size = len(query)*len(database)
        # NaN-initialised so that a cell the loop failed to fill would be visible.
        # NOTE(review): float16 represents integers exactly only up to 2048.
        similarities = {t: np.full((len(query), len(database)), np.nan, dtype=np.float16) for t in self.thresholds}

        index = get_faiss_index(d=self.descriptor_dim, device=self.device)
        for pair in tqdm(iterator, total=iterator_size, mininterval=1, ncols=100):
            (q_idx, q_data), (d_idx, d_data) = pair

            if (q_data is None) or (d_data is None):
                for t in self.thresholds:
                    similarities[t][q_idx, d_idx] = 0
                continue

            # The index is reused between pairs; only the stored vectors change.
            index.reset()
            index.add(q_data)
            score, _ = index.search(d_data, k=2)
            # x/0 gives inf and 0/0 gives NaN (identical descriptors); both compare
            # as "not below the threshold", so both warnings are silenced.
            # NOTE(review): faiss flat L2 indexes report squared distances, so the
            # thresholds presumably act on squared-distance ratios - confirm.
            with np.errstate(divide='ignore', invalid='ignore'):
                ratio = score[:, 0] / score[:, 1]
            for t in self.thresholds:
                similarities[t][q_idx, d_idx] = np.sum(ratio < t)

        return similarities


In [28]:
# Extract a set of SIFT local descriptors for each image.
extractor = SIFTFeatures()

features_query = extractor(dataset_query)
features_database = extractor(dataset_database)


# Count ratio-test matches for four thresholds at once; the result is a dict
# threshold -> matrix of shape (n_query, n_database).
similarity_func = MatchDescriptors(descriptor_dim=128, thresholds=[0.2, 0.4, 0.6, 0.8], device='cuda')
similarity = similarity_func.calculate(
    query=features_query,
    database=features_database,
)

100%|████████████████████████████████████████████████████████████| 164/164 [00:01<00:00, 104.94it/s]
100%|████████████████████████████████████████████████████████████| 656/656 [00:06<00:00, 105.32it/s]
100%|█████████████████████████████████████████████████████| 107584/107584 [00:16<00:00, 6536.85it/s]


In [29]:
# Extract a set of SIFT local descriptors for each image.
# NOTE(review): this repeats the extraction of the previous cell with the same inputs.
extractor = SIFTFeatures()

features_query = extractor(dataset_query)
features_database = extractor(dataset_database)


# Same matcher as before, but with the roles swapped: the database features are
# passed as `query` and vice versa, so each matrix has shape (n_database, n_query).
similarity_func = MatchDescriptors(descriptor_dim=128, thresholds=[0.2, 0.4, 0.6, 0.8], device='cuda')
similarity2 = similarity_func.calculate(
    query=features_database,
    database=features_query,
)

100%|████████████████████████████████████████████████████████████| 164/164 [00:01<00:00, 104.85it/s]
100%|████████████████████████████████████████████████████████████| 656/656 [00:06<00:00, 106.00it/s]
100%|█████████████████████████████████████████████████████| 107584/107584 [00:16<00:00, 6556.60it/s]


In [30]:
# Identity labels of both sides as string arrays.
labels_query = dataset_query.metadata['identity'].values.astype(str)
labels_database = dataset_database.metadata['identity'].values.astype(str)

import torch
# For every row of the threshold-0.2 matrix take the column with the largest
# match count (top-1) ...
scores, idx = torch.tensor(similarity[0.2], dtype=float).topk(k=1, dim=1, largest=True)
# ... and use the label of that database item as the prediction.
pred = labels_database[idx].flatten()

In [35]:
# Identity labels of both sides as string arrays.
labels_query = dataset_query.metadata['identity'].values.astype(str)
labels_database = dataset_database.metadata['identity'].values.astype(str)

import torch
# similarity2 was computed with the roles swapped, so it is transposed to put
# one query per row before the top-1 column is selected.
scores, idx = torch.tensor(similarity2[0.2].T, dtype=float).topk(k=1, dim=1, largest=True)
# NOTE(review): overwrites `pred` from the other prediction cell; the accuracy
# cells below evaluate whichever of the two ran last.
pred = labels_database[idx].flatten()

In [36]:
# Top-1 accuracy: fraction of query items whose predicted label equals the true label.
# NOTE(review): which `pred` is evaluated depends on cell execution order.
hits = (pred == labels_query)
sum(hits) / len(pred)

0.75

In [21]:
# Top-1 accuracy: fraction of query items whose predicted label equals the true label.
# NOTE(review): identical code to the neighbouring accuracy cells; the stored
# output presumably belongs to a different `pred` - verify before relying on it.
hits = (pred == labels_query)
sum(hits) / len(pred)

0.8963414634146342

In [17]:
# Top-1 accuracy: fraction of query items whose predicted label equals the true label.
# NOTE(review): identical code to the neighbouring accuracy cells; the stored
# output presumably belongs to a different `pred` - verify before relying on it.
hits = (pred == labels_query)
sum(hits) / len(pred)

0.7195121951219512

In [13]:
# Top-1 accuracy: fraction of query items whose predicted label equals the true label.
# NOTE(review): identical code to the neighbouring accuracy cells; the stored
# output presumably belongs to a different `pred` - verify before relying on it.
hits = (pred == labels_query)
sum(hits) / len(pred)

0.7439024390243902

In [None]:
def nn_classifier(distance, train_labels):
    '''
    1-nearest-neighbour prediction from a distance matrix.

    For every row of `distance` the column with the smallest value is found and
    the entry of `train_labels` at that column is returned as the prediction.

    Returns a tuple (predicted labels, smallest distances), both flattened to
    one entry per row of `distance`.
    '''
    nearest = torch.topk(torch.tensor(distance), k=1, dim=1, largest=False)
    predicted = train_labels[nearest.indices].flatten()
    return predicted, nearest.values.flatten()


def evaluate_closed(distance, train_labels, test_labels):
    '''
    Evaluate 1-nearest-neighbour predictions against the true test labels.

    Returns a dict with the single key 'acc' holding the value of
    metrics.accuracy(test_labels, prediction).

    NOTE(review): `metrics` is not imported in the visible part of this
    notebook - confirm where it comes from.
    '''
    # Only the predicted labels are needed; the distances are discarded.
    prediction, _ = nn_classifier(distance, train_labels)
    accuracy = metrics.accuracy(test_labels, prediction)
    return {'acc': accuracy}


In [None]:

# Rebuild the query dataset from the shared `splitter` object (test part).
dataset_query = WildlifeDataset(
    metadata=metadata,
    root=root,
    transform=transform,
    split=SplitWildlife(splitter, split='test')
)


# Cosine similarity between deep features
# NOTE(review): DeepFeatures, timm and CosineSimilarity are not imported in the
# visible part of this notebook, and this cell has never been executed.
# NOTE(review): `transform` converts images to grayscale - confirm that this is
# what the pretrained backbone should receive.
extractor = DeepFeatures(device='cuda', model=timm.create_model('swin_tiny_patch4_window7_224', num_classes=0, pretrained=True))

similarity_func = CosineSimilarity()

features_query = extractor(dataset_query)
features_database = extractor(dataset_database)

# NOTE(review): overwrites the `similarity` name also used for the SIFT match counts.
similarity = similarity_func.calculate(
    query=features_query,
    database=features_database,
)

In [17]:
import numpy as np
import pandas as pd
import itertools
import pandas as pd
import torch
import kornia.feature as KF
from data.dataset import WildlifeDataset
import math
import os
from tqdm import tqdm
import cv2
import faiss
import torch
import numpy as np
import fcntl
import time
import errno


def create_distance_matrix(sim, idx_train, idx_test):
    '''
    Turn an upper-triangular similarity matrix into a (test x train) distance block.

    The matching algorithm fills only the upper triangle of its output, so the
    matrix is mirrored across the diagonal, the requested block is selected and
    its sign is flipped (larger similarity -> smaller distance).
    The input itself is not modified (a float32 copy is used).

    Input: upper triangular matrix with shape (n_total, n_total) with zeros on diagonal.
    Output: distance matrix of shape (n_test, n_train).
        Diagonal entries, if they are selected, are -np.inf.

    Raises ValueError when the input is not upper triangular or when its
    diagonal is not all zeros.
    '''
    scores = sim.copy().astype(np.float32)

    upper_part = np.triu(scores)
    if not np.allclose(scores, upper_part):
        raise ValueError('Input matrix needs to be upper triangular.')

    diagonal = np.diag(scores)
    if np.any(diagonal != 0):
        raise ValueError('Input matrix needs to have zeros in diagonal.')

    # Self-pairs are marked with inf before mirroring (inf + inf stays inf).
    np.fill_diagonal(scores, np.inf)
    mirrored = scores + scores.T

    # Columns are restricted to train first, then rows to test.
    block = mirrored[:, idx_train][idx_test, :]
    return -block


import pickle
def single_file_save(path, similarity_new):
    '''
    Add `similarity_new` to the similarity dict stored in a shared pickle file.

    Several processes can call this on the same path: the file is protected by
    an exclusive, non-blocking `fcntl.flock` lock (Unix only) and a locked file
    is retried once per minute.

    Args:
        path: pickle file shared by all writers; created when missing.
        similarity_new: dict mapping a key (threshold) to an array. Arrays are
            added element-wise to the arrays already stored under the same
            key; keys not yet stored are inserted as they are.

    Raises:
        TimeoutError: the lock could not be obtained within 60 attempts.
        OSError: any file error other than "file is locked" (EAGAIN).
    '''
    max_attempts = 60 # One attempt per minute - try for one hour.
    wait_seconds = 60

    folder = os.path.dirname(path)
    if folder: # A bare file name has no folder to create.
        os.makedirs(folder, exist_ok=True)
    # Append mode creates the file when it is missing but never truncates it,
    # so data written by another process in the meantime cannot be wiped.
    with open(path, 'ab'):
        pass

    for attempt in range(max_attempts):
        try:
            with open(path, 'rb+') as file:
                fcntl.flock(file, fcntl.LOCK_EX | fcntl.LOCK_NB)
                try:
                    if os.path.getsize(path) == 0:
                        similarity = similarity_new
                    else:
                        # NOTE(review): pickle.load executes arbitrary code from
                        # the file - the path must only be writable by trusted jobs.
                        similarity = pickle.load(file)
                        for key, value in similarity_new.items():
                            if key in similarity:
                                similarity[key] = np.sum([similarity[key], value], axis=0)
                            else:
                                similarity[key] = value
                        file.seek(0)

                    pickle.dump(similarity, file, pickle.HIGHEST_PROTOCOL)
                    # Drop leftover bytes of a longer previous pickle and push the
                    # buffer to disk while the lock is still held.
                    file.truncate()
                    file.flush()
                finally:
                    fcntl.flock(file, fcntl.LOCK_UN)
                return
        except (OSError, IOError) as e:
            if e.errno != errno.EAGAIN:
                raise
            print('File locked - waiting...')
            if attempt + 1 < max_attempts:
                time.sleep(wait_seconds)

    # Returning silently here would lose this chunk of results.
    raise TimeoutError(f'Could not lock {path} after {max_attempts} attempts.')


def get_faiss_index(d, device='cpu'):
    '''
    Build an empty flat L2 faiss index for d-dimensional vectors.

    'cpu' gives faiss.IndexFlatL2; 'cuda' gives faiss.GpuIndexFlatL2 placed on
    GPU 0. Every other device value raises ValueError.
    '''
    if device not in ('cuda', 'cpu'):
        raise ValueError(f'Invalid device: {device}')

    if device == 'cpu':
        return faiss.IndexFlatL2(d)

    # GPU index: the device id is fixed to 0.
    gpu_resource = faiss.StandardGpuResources()
    gpu_config = faiss.GpuIndexFlatConfig()
    gpu_config.device = 0
    return faiss.GpuIndexFlatL2(gpu_resource, d, gpu_config)


def batched(iterable, n):
    '''
    Lazily group an iterable into tuples of n consecutive items.

    The final tuple holds the remaining items and may be shorter than n.
    Example: batched('ABCDEFG', 3) --> ABC DEF G

    Raises ValueError (on first iteration) when n is smaller than one.
    '''
    if n < 1:
        raise ValueError('n must be at least one')

    source = iter(iterable)
    while True:
        chunk = tuple(itertools.islice(source, n))
        if not chunk:
            # Source exhausted - stop the generator.
            return
        yield chunk


def prepare_pair_batches(query, database=None, batch_size=128, chunk=1, chunk_total=1):
    '''
    Prepares batches of data pairs given query and optionally the database.

    With a database, every query item is paired with every database item.
    Without one (database=None), all unordered pairs of query items are used.
    Each pair is ((index_a, item_a), (index_b, item_b)).

    The batches are divided into `chunk_total` contiguous chunks and only the
    batches of chunk number `chunk` (1-based) are returned, which allows the
    work to be distributed across several jobs.

    Returns:
        (iterator, n_batches): iterator over tuples of at most `batch_size`
        pairs and the number of batches it yields. A chunk without any batch
        gives an empty iterator and 0.
    '''
    # `is not None` instead of truthiness: an empty database must give zero
    # pairs, not silently switch to matching all pairs within the query.
    if database is not None:
        pair_total = len(query)*len(database)
        pair_iterator = itertools.product(enumerate(query), enumerate(database))
    else:
        pair_total = math.comb(len(query), 2)
        pair_iterator = itertools.combinations(enumerate(query), 2)

    batch_total = int(np.ceil(pair_total / batch_size))
    batch_iterator = batched(pair_iterator, batch_size)

    # Batch numbers that belong to the requested chunk.
    batches = np.array_split(np.arange(batch_total), chunk_total)[chunk-1]
    if len(batches) == 0:
        # More chunks than batches (or no pairs at all): nothing to do here.
        iterator = iter(())
    else:
        batch_min, batch_max = int(batches[0]), int(batches[-1]) + 1
        iterator = itertools.islice(batch_iterator, batch_min, batch_max)

    print(f'Total pairs     : {pair_total}')
    print(f'Total batches   : {batch_total}')
    print(f'Batches in chunk: {len(batches)}')
    return iterator, len(batches)


def prepare_pairs(query, database=None, batch_size=128, chunk=1, chunk_total=1):
    '''
    Prepares data pairs given query and optionally the database.

    With a database, every query item is paired with every database item.
    Without one (database=None), all unordered pairs of query items are used.
    Each pair is ((index_a, item_a), (index_b, item_b)).

    The pairs are divided into `chunk_total` contiguous chunks and only the
    pairs of chunk number `chunk` (1-based) are returned, which allows the
    work to be distributed across several jobs.

    `batch_size` is not used; it is kept so that the signature matches
    prepare_pair_batches.

    Returns:
        (iterator, n_pairs): iterator over the pairs of the chunk and their
        number. A chunk without any pair gives an empty iterator and 0.
    '''
    # `is not None` instead of truthiness: an empty database must give zero
    # pairs, not silently switch to matching all pairs within the query.
    if database is not None:
        pair_total = len(query)*len(database)
        pair_iterator = itertools.product(enumerate(query), enumerate(database))
    else:
        pair_total = math.comb(len(query), 2)
        pair_iterator = itertools.combinations(enumerate(query), 2)

    # Pair numbers that belong to the requested chunk.
    pairs = np.array_split(np.arange(pair_total), chunk_total)[chunk-1]
    if len(pairs) == 0:
        # More chunks than pairs (or no pairs at all): nothing to do here.
        iterator = iter(())
    else:
        pair_min, pair_max = int(pairs[0]), int(pairs[-1]) + 1
        iterator = itertools.islice(pair_iterator, pair_min, pair_max)

    print(f'Total pairs     : {pair_total}')
    print(f'Pairs in chunk  : {len(pairs)}')
    return iterator, len(pairs)


def compose_similarity(folders):
    '''
    Create similarity matrix given folders with similarity matrix chunks.

    Each folder has to contain a `similarity.pickle` file as written by the
    matchers' save() methods (a pickled dict key -> array). The arrays stored
    under the same key are summed element-wise over all folders.

    Args:
        folders: iterable of folder paths.

    Returns:
        dict with the keys of the last loaded chunk, each mapped to the sum of
        the corresponding arrays. An empty dict when no folder is given.

    Raises:
        KeyError: a chunk lacks a key that the last chunk has.
    '''
    data = []
    for folder in folders:
        path = os.path.join(folder, 'similarity.pickle')
        # The files are written with pickle.dump, so they are read with
        # pickle.load (np.load(...).item() fails on a plain pickled dict).
        # NOTE(review): pickle.load executes arbitrary code - trusted files only.
        with open(path, 'rb') as handle:
            data.append(pickle.load(handle))

    if not data:
        return {}

    return {
        key: np.sum([chunk[key] for chunk in data], axis=0)
        for key in data[-1].keys()
    }


class LOFTRMatcher():
    '''
    Pairwise image matcher built on the LoFTR model from kornia.

    For every image pair and every confidence threshold, the stored similarity
    is the number of LoFTR correspondences whose confidence exceeds the
    threshold. The work can be divided into chunks (see prepare_pair_batches).

    Args:
        device: device the model and the image batches are moved to.
        pretrained: name of the pretrained LoFTR weights.
        thresholds: confidence thresholds; one matrix is kept per threshold.
        batch_size: number of image pairs passed to the model at once.
        chunk: 1-based number of the chunk processed by this instance.
        chunk_total: number of chunks the work is divided into.
    '''
    def __init__(
        self,
        device: str ='cuda',
        pretrained: str ='outdoor',
        thresholds: tuple[float] = (0.99, ),
        batch_size: int = 128,
        chunk: int = 1,
        chunk_total: int = 1,
    ):
        self.device = device
        self.matcher = KF.LoFTR(pretrained=pretrained).to(device)
        self.thresholds = thresholds
        self.batch_size = batch_size

        chunk_out_of_range = chunk > chunk_total
        if chunk_out_of_range:
            raise ValueError('Current chunk is larger that chunk total.')
        self.chunk = chunk
        self.chunk_total = chunk_total

        # Filled by train() or load(): dict threshold -> match count matrix.
        self.similarity = None


    def train(
        self,
        dataset_query: WildlifeDataset,
        dataset_database: WildlifeDataset | None = None,
        **kwargs,
    ):
        '''
        Count confident LoFTR correspondences for the image pairs of this chunk.

        With a database, the result has shape (n_query, n_database). Without
        one, all unordered pairs of query images are matched and only entries
        [i, j] with i < j of an (n_query, n_query) matrix are written.
        Entries of pairs outside this chunk stay zero.
        '''
        with_database = bool(dataset_database)
        if with_database:
            print('Matching query with database')
        else:
            print('Matching all pairs in query')

        # Only the first element of every dataset item (the image) is used.
        query = [sample[0] for sample in dataset_query]
        database = [sample[0] for sample in dataset_database] if with_database else None
        n_columns = len(database) if with_database else len(query)
        self.similarity = {
            t: np.zeros((len(query), n_columns), dtype=np.float16)
            for t in self.thresholds
        }

        iterator, iterator_size = prepare_pair_batches(
            query=query,
            database=database,
            batch_size=self.batch_size,
            chunk=self.chunk,
            chunk_total=self.chunk_total,
        )

        for pair_batch in tqdm(iterator, total=iterator_size, mininterval=1):
            # Unzip the batch into indices and images of both sides.
            first, second = zip(*pair_batch)
            a_idx, a_data = zip(*first)
            b_idx, b_data = zip(*second)

            input_dict = {
                "image0": torch.stack(a_data).to(self.device),
                "image1": torch.stack(b_data).to(self.device),
            }
            with torch.inference_mode():
                correspondences = self.matcher(input_dict)

            # batch_idx says which pair of the batch each correspondence belongs to.
            batch_idx = correspondences['batch_indexes'].cpu().numpy()
            confidence = correspondences['confidence'].cpu().numpy()
            pairs_with_matches = np.unique(batch_idx)
            for t in self.thresholds:
                confident = confidence > t
                # Pairs without any correspondence keep their initial zero.
                for j in pairs_with_matches:
                    self.similarity[t][a_idx[j], b_idx[j]] = confident[batch_idx == j].sum()


    def save(self, folder, name='similarity.pickle', **kwargs):
        '''
        Pickle the similarity dict to folder/name; does nothing when it is empty or None.
        '''
        if not self.similarity:
            return
        target = os.path.join(folder, name)
        with open(target, 'wb') as handle:
            pickle.dump(self.similarity, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def load(self, path):
        '''
        Replace the similarity dict with the object pickled at `path`.
        '''
        with open(path, 'rb') as handle:
            self.similarity = pickle.load(handle)


class DescriptorMatcher():
    '''
    Matcher that counts ratio-test matches between sets of local descriptors.

    For every pair (a, b) the descriptors of `a` are put into a faiss index,
    the two nearest `a` descriptors are looked up for each descriptor of `b`,
    and a descriptor of `b` counts as a match when
    (nearest distance / second nearest distance) is below the threshold.
    The work can be divided into chunks (see prepare_pairs).

    Args:
        descriptor_function: callable turning a dataset into a list with one
            descriptor array (or None) per item. Subclasses may override
            get_descriptors instead.
        descriptor_dim: dimension of a single descriptor vector.
        max_keypoints: optional limit of keypoints, for use by get_descriptors.
        thresholds: ratio thresholds; one matrix is kept per threshold.
        device: 'cpu' or 'cuda', forwarded to get_faiss_index.
        chunk: 1-based number of the chunk processed by this instance.
        chunk_total: number of chunks the work is divided into.
        joint_save: optional path of a file shared by all chunks; when set,
            save() adds the result to that file instead of writing its own.

    Raises:
        ValueError: chunk is larger than chunk_total.
    '''
    def __init__(
        self,
        descriptor_function = None,
        descriptor_dim: int = 128,
        max_keypoints: int | None = None,
        thresholds: tuple[float, ...] = (0.5, ),
        device: str = 'cpu',
        chunk: int = 1,
        chunk_total: int = 1,
        joint_save: str | None = None,
    ):

        self.descriptor_function = descriptor_function
        self.descriptor_dim = descriptor_dim
        self.max_keypoints = max_keypoints
        self.thresholds = thresholds
        self.device = device
        if chunk > chunk_total:
            raise ValueError('Current chunk is larger that chunk total.')
        self.chunk = chunk
        self.chunk_total = chunk_total
        self.joint_save = joint_save
        # Filled by train() or load(): dict threshold -> match count matrix.
        self.similarity = None

    def get_descriptors(self, dataset):
        '''
        Return the descriptors of `dataset` using the configured function.

        Raises ValueError when no descriptor function was provided.
        '''
        if self.descriptor_function:
            return self.descriptor_function(dataset)
        else:
            raise ValueError('No descriptor function provided.')

    def train(
        self,
        dataset_query: WildlifeDataset,
        dataset_database: WildlifeDataset | None = None,
        **kwargs,
    ):
        '''
        Count ratio-test matches for the pairs of this chunk.

        With a database, the result has shape (n_query, n_database). Without
        one, all unordered pairs of query items are matched and only entries
        [i, j] with i < j of an (n_query, n_query) matrix are written.
        Entries of pairs outside this chunk, and of pairs where either side
        has no descriptors (None), stay zero.
        '''
        # NOTE(review): float16 represents integers exactly only up to 2048.
        if dataset_database:
            print('Mode: Matching query with database')
            query = self.get_descriptors(dataset_query)
            database = self.get_descriptors(dataset_database)
            self.similarity = {t: np.zeros((len(query), len(database)), dtype=np.float16) for t in self.thresholds}
        else:
            print('Mode: Matching all pairs in query')
            query = self.get_descriptors(dataset_query)
            database = None
            self.similarity = {t: np.zeros((len(query), len(query)), dtype=np.float16) for t in self.thresholds}

        iterator, iterator_size = prepare_pairs(
            query = query,
            database = database,
            chunk = self.chunk,
            chunk_total = self.chunk_total
        )

        index = get_faiss_index(d=self.descriptor_dim, device=self.device)
        for (a_idx, a_data), (b_idx, b_data) in tqdm(iterator, total=iterator_size, mininterval=1):
            if (a_data is None) or (b_data is None):
                continue

            # The index is reused between pairs; only the stored vectors change.
            index.reset()
            index.add(a_data)
            score, _ = index.search(b_data, k=2)
            # x/0 gives inf and 0/0 gives NaN (identical descriptors); both compare
            # as "not below the threshold", so both warnings are silenced.
            with np.errstate(divide='ignore', invalid='ignore'):
                ratio = score[:, 0] / score[:, 1]
            for t in self.thresholds:
                self.similarity[t][a_idx, b_idx] = np.sum(ratio < t)


    def save(self, folder, name='similarity.pickle', **kwargs):
        '''
        Store the similarity dict; does nothing when it is empty or None.

        With `joint_save` set, the result is added to that shared file and
        `folder` / `name` are ignored. Otherwise it is pickled to folder/name.
        '''
        if not self.similarity:
            return

        if self.joint_save:
            single_file_save(self.joint_save, self.similarity)
        else:
            with open(os.path.join(folder, name), 'wb') as handle:
                pickle.dump(self.similarity, handle, protocol=pickle.HIGHEST_PROTOCOL)


    def load(self, path):
        '''
        Replace the similarity dict with the object pickled at `path`.
        '''
        with open(path, 'rb') as handle:
            self.similarity = pickle.load(handle)

class SIFTMatcher(DescriptorMatcher):
    '''
    DescriptorMatcher that computes its descriptors with OpenCV SIFT.

    Accepts the same arguments as DescriptorMatcher; the descriptor dimension
    is always set to 128 regardless of the value passed in.
    '''
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Overrides whatever descriptor_dim the caller passed.
        self.descriptor_dim = 128

    def get_descriptors(self, dataset):
        '''
        Compute SIFT descriptors for every item of `dataset`.

        Returns a list with one entry per item: the descriptor array, or None
        when at most one keypoint was detected.
        '''
        # Limit the number of features only when max_keypoints is set.
        sift_kwargs = {'nfeatures': self.max_keypoints} if self.max_keypoints else {}
        sift = cv2.SIFT_create(**sift_kwargs)

        descriptors = []
        for image, _label in tqdm(dataset, mininterval=1):
            keypoints, found = sift.detectAndCompute(np.array(image), None)
            # The ratio test needs two neighbours, so fewer than two keypoints
            # are treated as "no descriptors".
            descriptors.append(found if len(keypoints) > 1 else None)
        return descriptors
