## Similarity

In [11]:
import pandas as pd
from data.dataset import WildlifeDataset
from torchvision import transforms as T
from features.utils import save_features, load_features

from collections import defaultdict
from tqdm import tqdm

import faiss
import itertools
import numpy as np

root = '/home/cermavo3/projects/wildlife-experiments/data_256/StripeSpotter'
metadata = pd.read_csv(f'{root}/annotations.csv')

## Deep features - Cosine similarity

In [2]:
#from sklearn.metrics.pairwise import cosine_similarity
from matching.utils import cosine_similarity

deep_features = load_features('temp/stripespotter_deep_features.pickle')
deep_sim = cosine_similarity(deep_features, deep_features)

## Number of correspondences

In [9]:
def batched(iterable, n):
    '''
    Batch data into tuples of length n. The last batch may be shorter.
    Example: batched('ABCDEFG', 3) --> ABC DEF G
    '''
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    while batch := tuple(itertools.islice(it, n)):
        yield batch


def get_query_iterator(query_idx, query, database):
    ''' Generator of query-database pairs for given query data point. '''
    for database_idx, database_data in enumerate(database):
        yield (query_idx, query[query_idx]), (database_idx, database_data)


def create_iterators(query, database, batch_size = None, chunk=1, chunk_total=1):
    '''
    Given query and database, create list of pair iterators one for each item in query.
    If batch size is given, add batch iterator before the pair iterator.
    If chunks are given, split query in chunks.
    '''
    subset = np.array_split(np.arange(len(query)), chunk_total)[chunk-1]

    iterators = []
    for i in subset:
        iterator = get_query_iterator(i, query, database)
        if batch_size:
            iterator = batched(iterator, batch_size)
        iterators.append((i, iterator))
    return iterators


def get_faiss_index(d, device='cpu'):
    if device == 'cuda':
        resource = faiss.StandardGpuResources()
        config = faiss.GpuIndexFlatConfig()
        config.device = 0
        return faiss.GpuIndexFlatL2(resource, d, config)
    elif device == 'cpu':
        return faiss.IndexFlatL2(d)
    else:
        raise ValueError(f'Invalid device: {device}')

In [171]:
class DescriptorMatcher():
    def __init__(
        self,
        descriptor_dim: int = 128,
        thresholds: tuple[float] = (0.5, ),
        device: str = 'cpu',
        chunk: int = 1,
        chunk_total: int = 1,
        joint_save: str | None = None,
    ):

        self.descriptor_dim = descriptor_dim
        self.thresholds = thresholds
        self.device = device
        if chunk > chunk_total:
            raise ValueError('Current chunk is larger that chunk total.')
        self.chunk = chunk
        self.chunk_total = chunk_total
        self.joint_save = joint_save
        self.similarity = None

self = DescriptorMatcher()
features = load_features('temp/stripespotter_sp_features.pickle')
database = load_features('temp/stripespotter_sp_features.pickle')[:500]
query = load_features('temp/stripespotter_sp_features.pickle')[:50]

iterators = create_iterators(
    query = query,
    database = query,
    chunk = self.chunk,
    chunk_total = self.chunk_total
)

idx, iterator = iterators[0]
print(idx)


#similarity = {t: defaultdict(lambda: np.zeros(len(database), dtype=np.float16)) for t in self.thresholds}
similarity = {t: {} for t in self.thresholds}
index = get_faiss_index(d=256, device=self.device)
for query_idx, iterator in tqdm(iterators):
    for t in self.thresholds:
        similarity[t][query_idx] = np.zeros(len(database), dtype=np.float16)

    for (_, query_data), (database_idx, database_data) in iterator:
        
        # If no decriptors were found
        if (query_data is None) or (database_data is None):
            print('xxxz')
            continue

        index.reset()
        index.add(database_data)
        score, idx = index.search(query_data, k=2)
        with np.errstate(divide='ignore'):
            ratio = score[:, 0] / score[:, 1]

        for t in self.thresholds:
            similarity[t][query_idx][database_idx] = np.sum(ratio < t)
    break

0


  0%|                                                                                                                                 | 0/50 [00:00<?, ?it/s]


In [172]:
similarity

{0.5: {0: array([42.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.,  0.,
          1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  0.,  0.,  0.,  

In [110]:
from tqdm import tqdm
index = get_faiss_index(d=256, device=self.device)

similarities = {t: defaultdict(lambda: np.zeros(len(database), dtype=np.float16)) for t in self.thresholds}
for query_idx, iterator in tqdm(iterators):
    matches = np.zeros(len(database), dtype=np.float16)
    for query_data, (database_idx, database_data) in iterator:
        if (query_data is None) or (database_data is None):
            continue

        index.reset()
        index.add(database_data)
        score, idx = index.search(query_data, k=2)
        with np.errstate(divide='ignore'):
            ratio = score[:, 0] / score[:, 1]
        for t in self.thresholds:
            similarities[t][query_idx][database_data] = np.sum(ratio < t)



0


## V2

In [173]:
database = load_features('temp/stripespotter_sp_features.pickle')[:500]
query = load_features('temp/stripespotter_sp_features.pickle')[:50]


chunk_q_total = 1
chunk_q = 1
chunk_d_total = 1
chunk_d = 1

q_subset = np.array_split(np.arange(len(query)), chunk_q_total)[chunk_q-1]
d_subset = np.array_split(np.arange(len(database)), chunk_d_total)[chunk_d-1]

iterator = itertools.product(enumerate(q_subset), enumerate(d_subset))
iterator_size = len(q_subset)*len(d_subset)
iterator = add_iterator_data(iterator, query, database)


batch_size = None
if batch_size:
    iterator = batched(iterator, batch_size)


similarity = {t: np.zeros((len(query), len(database)), dtype=np.float16) for t in self.thresholds}



for (q_idx, q_data), (d_idx, d_data) in tqdm(iterator, total=iterator_size, mininterval=1):
    if (d_data is None) or (q_data is None):
        for t in self.thresholds:
            similarity[t][q_idx][d_idx] = 0
    else:
        index.reset()
        index.add(d_data)
        score, idx = index.search(q_data, k=2)
        with np.errstate(divide='ignore'):
            ratio = score[:, 0] / score[:, 1]

        for t in self.thresholds:
            similarity[t][q_idx][d_idx] = np.sum(ratio < t)

### Number of LOFT correspondences

In [17]:
'''
import sys
sys.path.append('/home/cermavo3/projects/wildlife-tools/models')
sys.path.append('/home/cermavo3/projects/wildlife-tools')
from matchers import LOFTRMatcher
matcher = LOFTRMatcher(thresholds=[0.73])
dataset = WildlifeDataset(metadata.iloc[:50], root=root, transform=transform)
sim_v1 = matcher.train(dataset_query=dataset, dataset_database=dataset)

'''

In [19]:
import torch
import pandas as pd
from data.dataset import WildlifeDataset
from torchvision import transforms as T

root = '/home/cermavo3/projects/wildlife-experiments/data_256/StripeSpotter'
metadata = pd.read_csv(f'{root}/annotations.csv')

transform = T.Compose([
    T.Resize(size=256),
    T.CenterCrop(size=(224, 224)),
    T.Grayscale(),
    T.ToTensor(),
])


In [None]:
'matcher': {
    'method': 'descriptor',
    'device': 'cuda',
    'num_dim': 256,
}

    
pipeline: [
    'load_data',
    'matcher',
]

In [None]:
'matcher': {
    'method': 'descriptor',
    'device': 'cuda',
    'num_dim': 256,
}

'query_dataset': {
    'method': 'folder',
    'root': '/home/cermavo3/projects/wildlife-experiments/data_256/StripeSpotter',
    'transform': {
    }
}

pipeline: [
    'load_data',
    'matcher' {
        'query' : {
            'query_dataset'
        }
    },
]

In [None]:
# Match -> similarity
# Inference

## V2

In [100]:
def add_iterator_data(iterator, data_a, data_b):
    ''' Convert data-index to data-point given pair iterator and source data. '''
    for (a_idx, a), (b_idx, b) in iterator:
        yield (a_idx, data_a[a]), (b_idx, data_b[b])
        

def get_indexes(array, chunk, chunk_total):
    return np.array_split(np.arange(len(features)), chunk_total)[chunk-1]

In [None]:
matcher = LOFTRMatcher(device='cuda', thresholds=[0.73])


dataset = WildlifeDataset(metadata.iloc[:50], root=root, transform=transform)

# Inputs

d_chunk = 1
d_chunk_total = 1
d_subset = np.array_split(np.arange(len(database)), chunk_d_total)[chunk_d-1]
database = [i[0] for i in dataset]

query = [i[0] for i in dataset]
q_chunk = 1
q_chunk_total = 1
q_subset = np.array_split(np.arange(len(query)), chunk_q_total)[chunk_q-1]


In [None]:
class LoadedDataset():
    

In [97]:

import kornia.feature as KF
class LOFTRMatcher():
    def __init__(
        self,
        device: str ='cuda',
        pretrained: str ='outdoor',
        thresholds: tuple[float] = (0.99, ),
        batch_size: int = 128,
    ):
        self.device = device
        self.matcher = KF.LoFTR(pretrained=pretrained).to(device)
        self.thresholds = thresholds
        self.batch_size = batch_size


    def match(self, query, database):
        iterator = batched(itertools.product(enumerate(query), enumerate(database)), self.batch_size)
        iterator_size = int(np.ceil(len(query)*len(database) / self.batch_size))
        similarity = {t: np.full((len(query), len(database)), np.nan, dtype=np.float16) for t in self.thresholds}

        for pair_batch in tqdm(iterator, total=iterator_size, mininterval=1):
            q, d = zip(*pair_batch)
            q_idx, q_data = list(zip(*q))
            d_idx, d_data = list(zip(*d))
            input_dict = {
                "image0": torch.stack(a_data).to(self.device),
                "image1": torch.stack(b_data).to(self.device),
            }
            with torch.inference_mode():
                correspondences = self.matcher(input_dict)

            batch_idx = correspondences['batch_indexes'].cpu().numpy()
            confidence = correspondences['confidence'].cpu().numpy()
            for t in self.thresholds:
                series = pd.Series(confidence > t)
                for j, group in series.groupby(batch_idx):
                    similarity[t][q_idx[j], d_idx[j]] = group.sum()
        return similarity


100%|██████████| 20/20 [00:21<00:00,  1.09s/it]


In [95]:
for (q_store_idx, q_idx), (d_store_idx, d_idx) in iterator:
    break

In [84]:
q, d = zip(*pair_batch)
q_idx, q_data = list(zip(*[(i, query[j]) for i, j in q]))
b_idx, database_data = list(zip(*[(i, database[j]) for i, j in d]))

In [98]:
(matcher.similarity[0.73] == similarity[0.73]).all()

True

In [None]:
#query
load_descriptors()
Chunks

database
#load_descriptor()
Chunks

DescriptorMatcher()