In [0]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np

from matplotlib import pyplot as plt
%matplotlib inline

import random
from pathlib import Path

from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.metrics.pairwise import manhattan_distances

import sys

from hungarian import Hungarian

seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [3]:
# Detect if we are in Google Colaboratory
try:
    import google.colab
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

from pathlib import Path
# Determine the locations of auxiliary libraries and datasets.
if IN_COLAB:
    google.colab.drive.mount("/content/drive")
    
    # Change this if you created the shortcut in a different location
    AUX_DATA_ROOT = Path("/content/drive/My Drive/project dl")
    
    assert AUX_DATA_ROOT.is_dir(), "Have you forgotten to 'Add a shortcut to Drive'?"
    
    import sys
    sys.path.insert(0, str(AUX_DATA_ROOT))
else:
    AUX_DATA_ROOT = Path(".")

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
!rm -r embeddings

In [0]:
!unzip -q "{AUX_DATA_ROOT / 'embeddings.zip'}"

In [0]:
!unzip -q "{AUX_DATA_ROOT / 'embeddings_normalized.zip'}"
!unzip -q "{AUX_DATA_ROOT / 'edge_embeddings.zip'}"
!unzip -q "{AUX_DATA_ROOT / 'adjacency_matrices.zip'}"

In [0]:
class BirdsGraphDataset(Dataset):
    def __init__(self, use_graphwave=False):
        super().__init__()

        self.use_graphwave = use_graphwave
    
    def __len__(self):
        return 11787
    
    def __getitem__(self, pair):
        i,j = pair
        # i, j = 0, 0
        # while i == j:
        #     i, j = np.random.randint(0, 11786, 2)
            
        with open(f'embeddings/{i}', 'rb') as fp:
            H1 = torch.load(fp, map_location=torch.device('cpu')).squeeze()
        with open(f'edge_embeddings/{i}', 'rb') as fp:
            E1 = torch.load(fp, map_location=torch.device('cpu')).squeeze()
        with open(f'adjacency_matrices/{i}.npy', 'rb') as fp:
            A1 = np.load(fp)
      
        
        with open(f'embeddings/{j}', 'rb') as fp:
            H2 = torch.load(fp, map_location=torch.device('cpu')).squeeze()
        with open(f'edge_embeddings/{j}', 'rb') as fp:
            E2 = torch.load(fp, map_location=torch.device('cpu')).squeeze()
        with open(f'adjacency_matrices/{j}.npy', 'rb') as fp:
            A2 = np.load(fp)
            
        if self.use_graphwave:
            with open(f'embeddings_graphwave/{i}', 'rb') as fp:
                G1 = np.load(fp)
            with open(f'embeddings_graphwave/{j}', 'rb') as fp:
                G2 = np.load(fp)
            H1 = torch.cat([torch.tensor(H1), torch.tensor(G1)], dim=-1)
            H2 = torch.cat([torch.tensor(H2), torch.tensor(G2)], dim=-1)
            
        visible1 = A1.sum(axis=1)
        visible2 = A2.sum(axis=1)
        
        mask = ((visible1 != 0) & (visible2 != 0))
        
        return H1, E1, H2, E2, torch.tensor(mask)

In [0]:
Birds = BirdsGraphDataset() 

In [0]:
test_idx_pairs = np.loadtxt('drive/My Drive/project dl/train_test_indexes/test_idx_pairs.txt').astype(int)
np.shape(test_idx_pairs)

test_idx_pairs = np.where(test_idx_pairs == 11788, 11003, test_idx_pairs)

## Mutual check NN


In [0]:
# Given Birds custom dataset and test_idx_pairs this function calculates the
# accuracy of NN matcher with mutual check constraint (one keypoint could be
# matched only with one keypoint, implemented with Hungarian algorithm)

# masked_accuracy shows how many keypoints were classified right among visible ones
# 'Some sort of needed acuracy' means the aforementioned right matches + # of invisible
# points that were matched to other invisible points (because we do not care about them,
# this values will be masked)
def nn_mutual_accuracy(dataset, test_idx_pairs):
    true_matches = 0
    true_matches_masked = 0
    matches_masked = 0
    sample_size = np.shape(test_idx_pairs)[0]
    for k in range(sample_size):
        i,j = test_idx_pairs[k,:]
        H1, E1, H2, E2, mask = dataset[(i,j)]

        n_keypoints = H1.shape[0]

        kernel = np.abs(cosine_similarity(H1.detach().numpy(), H2.detach().numpy()))

        hungarian = Hungarian(kernel, is_profit_matrix=True)
        hungarian.calculate()

        idx = hungarian.get_results()
        idx = torch.tensor(sorted(idx, key=lambda x: x[0]))

        true_match = (idx[mask,0] == idx[mask,1]).sum()

        a = idx[~mask,0]
        b = idx[~mask,1]

        not_seen_match = a.view(1,-1).eq(b.view(-1,1)).sum(0).sum().item()

        true_matches_masked += true_match.item()
        matches_masked += mask.sum().item()
        true_matches += true_match.item() + not_seen_match#(n_keypoints - mask.sum().item())

        if k % 1000 == 0:
            print(k)

    return true_matches_masked / matches_masked, true_matches / (n_keypoints*sample_size)

In [14]:
acc_masked, acc = nn_mutual_accuracy(Birds, test_idx_pairs)
print()
print('Accuracy for masked elements: %.3f' %acc_masked)
print('Sort of needed accuracy:      %.3f' %acc)

0
1000
2000
3000
4000
5000
6000
7000
8000
9000

Accuracy for masked elements: 0.097
Sort of needed accuracy:      0.271


## Thresholded NN (for 2nd NN)

In [0]:
# Given Birds custom dataset, test_idx_pairs and threshold value this function calculates the
# accuracy of NN matcher with threshold constraint (we find two nearest neighbors
# and if the ratio of their respective distances is greater than threshold value 
# ---> we match it into some invisible keypoint)

def evaluate_accuracy_thr(dataset, test_idx_pairs, thr, metric='cos'):
    true_matches = 0
    true_matches_masked = 0
    matches_masked = 0
    sample_size = np.shape(test_idx_pairs)[0]
    for k in range(sample_size):
        i,j = test_idx_pairs[k,:]
        H1, E1, H2, E2, mask = dataset[(i,j)]
        n_keypoints = H1.shape[0]
        
        if metric == 'cos':
            kernel = np.abs(cosine_similarity(H1.detach(), H2.detach()))
        elif metric == 'l2':
            kernel = euclidean_distances(H1.detach(), H2.detach())
        elif metric == 'l1':
            kernel = manhattan_distances(H1.detach(), H2.detach())

        k = kernel.shape[0]

        val, idx = torch.topk(torch.tensor(kernel), 2, dim=0)

        val = val[1,:] / val[0,:]

        val = (val < thr)

        idx[0,~val] = torch.arange(0,n_keypoints)[~mask][0]
        idx = idx[0]

        indices = torch.ones((2,n_keypoints), dtype=int)
        indices[0,:] = torch.arange(0,n_keypoints)
        indices[1,:] = idx

        true_match = (indices[0,mask] == indices[1,mask]).sum()

        a = indices[0,~mask]
        b = indices[1,~mask]

        not_seen_match = a.view(1,-1).eq(b.view(-1,1)).sum(0).sum().item()

        true_matches_masked += true_match.item()
        matches_masked += mask.sum().item()
        true_matches += true_match.item() + not_seen_match

        # if i % 500 == 0:
        #     print(true_matches)

    return true_matches_masked / matches_masked, true_matches / (n_keypoints*sample_size)

In [22]:
thr = 0.9

acc_masked, acc = evaluate_accuracy_thr(Birds, test_idx_pairs, thr, metric='cos')
print()
print('Cosine similarity:')
print('Accuracy for masked elements: %.3f' %acc_masked)
print('Sort of needed accuracy:      %.3f' %acc)


Cosine similarity:
Accuracy for masked elements: 0.037
Sort of needed accuracy:      0.323


In [23]:
thr = 0.9

acc_masked, acc = evaluate_accuracy_thr(Birds, test_idx_pairs, thr, metric='l1')
print()
print('L1 distance:')
print('Accuracy for masked elements: %.3f' %acc_masked)
print('Sort of needed accuracy:      %.3f' %acc)


L1 distance:
Accuracy for masked elements: 0.015
Sort of needed accuracy:      0.235


In [24]:
thr = 0.9

acc_masked, acc = evaluate_accuracy_thr(Birds, test_idx_pairs, thr, metric='l2')
print()
print('L2 distance:')
print('Accuracy for masked elements: %.3f' %acc_masked)
print('Sort of needed accuracy:      %.3f' %acc)


L2 distance:
Accuracy for masked elements: 0.024
Sort of needed accuracy:      0.248
