In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell executes, so edits to local modules are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import ConnectionPatch
import pickle as pkl

from dmd_api import DmdExtractor, DmdMatcher
import dmd_utils as dmdu

# Dados

In [None]:
# Example: load one sample (item 1) from the dataset pickle and plot its
# minutiae over the image, with text annotations enabled.
example_sample = dmdu.load_dmd_format('../datasets/sd258.pkl', item_id=1)
img, mnt = example_sample
dmdu.plot_mnt(img, mnt, text=True)

In [None]:
total = len(dmdu.FNAMES)
print(f'Total number of samples: {total}')
all_idx = list(range(total))

# Even positions become queries, odd positions become gallery entries.
queries_idx = all_idx[0::2]
gallery_idx = all_idx[1::2]

query_imgs, query_mnts = [], []
gallery_imgs, gallery_mnts = [], []

# Load queries first, then gallery, appending each sample's image and
# minutiae to the matching pair of lists.
for split_idx, split_imgs, split_mnts in (
    (queries_idx, query_imgs, query_mnts),
    (gallery_idx, gallery_imgs, gallery_mnts),
):
    for i in tqdm(split_idx):
        img, mnt = dmdu.load_dmd_format('../datasets/sd258.pkl', i)
        split_imgs.append(img)
        split_mnts.append(mnt)

# Extração

In [None]:
# Prefer the GPU when torch reports one, otherwise run on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
extractor = DmdExtractor('../logs/DMD++/best_model.pth.tar', device=device)

# Set to True to recompute the templates; False reuses the pickled ones.
EXTRACT = False

if not EXTRACT:
    # Load previously saved templates from the working directory.
    # NOTE(review): pickle.load executes arbitrary code from the file; only
    # use template files produced by this notebook.
    with open('query_templates.pkl', 'rb') as f:
        query_templates = pkl.load(f)
    with open('gallery_templates.pkl', 'rb') as f:
        gallery_templates = pkl.load(f)
else:
    # Extract one template per (image, minutiae) pair, queries first.
    query_templates = []
    gallery_templates = []
    for src_imgs, src_mnts, dest_templates in (
        (query_imgs, query_mnts, query_templates),
        (gallery_imgs, gallery_mnts, gallery_templates),
    ):
        for img, mnt in tqdm(zip(src_imgs, src_mnts), total=len(src_imgs)):
            template = extractor.extract(img, mnt)
            dest_templates.append(template)

    # Persist both template lists so later runs can skip extraction.
    for out_path, out_templates in (
        ('query_templates.pkl', query_templates),
        ('gallery_templates.pkl', gallery_templates),
    ):
        with open(out_path, 'wb') as f:
            pkl.dump(out_templates, f)

# Identificação

In [None]:
# True: read the score matrix from 'scores.npy'.
# False: recompute it with the matcher and overwrite the cached file.
CACHED = True
if not CACHED:
    matcher = DmdMatcher()
    scores = matcher.identify(query_templates, gallery_templates)

    np.save('scores.npy', scores)
else:
    scores = np.load('scores.npy')

In [None]:
# Heat map of the score matrix, with the y-axis limited to the range 0..50.
plt.imshow(scores, cmap='hot')
plt.ylim(0, 50)

# Avaliação

In [None]:
def get_cmc(scores, max_rank=50):
    """Compute the Cumulative Match Characteristic curve from a score matrix.

    The true match of query ``i`` is assumed to be gallery entry ``i``, and
    higher scores mean better matches.

    Args:
        scores: 2-D array of shape (num_queries, num_gallery).
        max_rank: Number of ranks in the returned curve.

    Returns:
        A tuple ``(cmc, error_cases)``. ``cmc[k]`` is the fraction of queries
        whose true match appears within the top ``k + 1`` candidates.
        ``error_cases`` holds one dict per query whose true match is not at
        rank 1 (keys ``query_idx``, ``gallery_true_idx``,
        ``gallery_predicted_idx``, ``rank``), sorted by ascending 1-based rank.

    Raises:
        AssertionError: If there are more queries than gallery samples.
    """
    n_queries, n_gallery = scores.shape
    assert n_queries <= n_gallery, "Number of queries should be less than or equal to number of gallery samples"

    hits = np.zeros(max_rank)
    misses = []
    for query in range(n_queries):
        # Gallery indices ordered from best to worst score for this query
        ranking = np.argsort(-scores[query])

        # 0-based position of the true match inside that ordering
        position = np.flatnonzero(ranking == query)[0]

        # A hit at this position also counts for every larger rank
        if position < max_rank:
            hits[position:] += 1

        # Anything not ranked first is recorded as an error case
        if position > 0:
            misses.append({
                'query_idx': query,
                'gallery_true_idx': query,
                'gallery_predicted_idx': ranking[0],
                'rank': position + 1
            })

    misses.sort(key=lambda case: case['rank'])
    return hits / n_queries, misses

def plot_cmc(cmc):
    """Plot a CMC curve in percent and annotate ranks 1, 5, 10 and 20.

    Args:
        cmc: 1-D array of matching rates as fractions; entry ``k`` belongs
            to rank ``k + 1``.
    """
    percentages = cmc * 100  # fractions -> percent
    n_ranks = len(percentages)

    plt.figure()
    plt.plot(np.arange(1, n_ranks + 1), percentages)

    # Mark selected ranks with a red 'x' and a text label, skipping any rank
    # beyond the end of the curve
    for rank in (1, 5, 10, 20):
        if rank > n_ranks:
            continue
        value = percentages[rank - 1]
        plt.plot(rank, value, 'xr')
        plt.text(rank, value, f'Rank-{rank}: {value:.2f}', fontsize=9, verticalalignment='bottom')

    plt.xlabel('Rank')
    plt.ylabel('Matching Rate')
    plt.title('CMC Curve')
    plt.grid(True)
    plt.show()

# Compute the CMC curve over 256 ranks. error_cases lists every query whose
# true match was not ranked first, so "Misses" counts rank-1 failures.
cmc, error_cases = get_cmc(scores, max_rank=256)
print(f'Misses: {len(error_cases)}')
plot_cmc(cmc)

# Visualização

In [None]:
def plot_minutiae_matches(imagem_query, imagem_gallery, template_query, template_gallery, pairs, max_score=10):
    """Show query and gallery images side by side and connect matched minutiae.

    Args:
        imagem_query: Query image, drawn in grayscale on the left axes.
        imagem_gallery: Gallery image, drawn in grayscale on the right axes.
        template_query: Mapping whose ``'mnt'`` entry is a tensor of minutiae;
            its leading dimension is squeezed and the first two columns of
            each row are used as the (x, y) position.
        template_gallery: Same as ``template_query``, for the gallery image.
        pairs: Sequence of matches, each either ``(idx_q, idx_g)`` or
            ``(score, idx_q, idx_g)``. The layout is detected from the first
            element. With scores, lines are coloured with the 'viridis'
            colormap and their opacity follows the normalised score;
            without scores, lines are drawn in opaque green.
        max_score: Currently unused; kept so existing callers keep working.
    """
    # 1. Set up the figure and both axes
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 7))

    # 2. Show the images
    ax1.imshow(imagem_query, cmap='gray')
    ax1.set_title('Query')
    ax1.axis('off')

    ax2.imshow(imagem_gallery, cmap='gray')
    ax2.set_title('Gallery')
    ax2.axis('off')

    # 3. Pull the minutiae out of the templates:
    # PyTorch tensor (1, N, 3) -> numpy array (N, 3)
    query_mnt = template_query['mnt'].squeeze(0).cpu().numpy()
    gallery_mnt = template_gallery['mnt'].squeeze(0).cpu().numpy()

    # 4. Prepare the colour map when scores are provided.
    # len() is used instead of truthiness so arrays/tensors of pairs work too.
    use_color_map = len(pairs) > 0 and len(pairs[0]) == 3
    if use_color_map:
        # Plain floats keep min/max and the colormap lookup independent of
        # the scalar type (Python float, numpy scalar or 0-d tensor)
        scores = [float(p[0]) for p in pairs]
        min_s, max_s = min(scores), max(scores)
        cmap = plt.get_cmap('viridis')

        # Add a colour-bar legend shared by both axes
        norm = plt.Normalize(vmin=min_s, vmax=max_s)
        sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
        fig.colorbar(sm, ax=[ax1, ax2], orientation='vertical', label='Score da Correspondência', shrink=0.8)

    # 5. Draw one connection line per pair
    for pair in pairs:
        if use_color_map:
            score, idx_q, idx_g = pair
            # Normalise the score to [0, 1); the epsilon avoids division by
            # zero when all scores are equal
            norm_score = (float(score) - min_s) / (max_s - min_s + 1e-6)
            color = cmap(norm_score)
            alpha = norm_score
        else:
            idx_q, idx_g = pair
            color = 'green'  # Default colour when there is no score
            # No score to derive opacity from; previously this branch hit a
            # NameError on the undefined norm_score
            alpha = 1.0

        # (x, y) position of each minutia
        start_point = query_mnt[idx_q][:2]
        end_point = gallery_mnt[idx_g][:2]

        # Line from the gallery minutia (right axes) to the query minutia
        # (left axes), drawn at figure level so it can span both axes
        con = ConnectionPatch(
            xyA=end_point,
            xyB=start_point,
            coordsA=ax2.transData,
            coordsB=ax1.transData,
            color=color,
            linewidth=1,
            alpha=alpha
        )
        fig.add_artist(con)

        # Mark the matched minutiae on both images
        ax1.plot(start_point[0], start_point[1], 'o', markerfacecolor=color, markeredgecolor='white', markersize=5)
        ax2.plot(end_point[0], end_point[1], 'o', markerfacecolor=color, markeredgecolor='white', markersize=5)

In [None]:
# List every error case (already sorted by ascending rank by get_cmc) so an
# index can be picked for the visualisation below.
for ec in error_cases:
    print(f"Query idx: {ec['query_idx']}, Gallery true idx: {ec['gallery_true_idx']}, Gallery predicted idx: {ec['gallery_predicted_idx']}, Rank: {ec['rank']}")

In [None]:
# Error cases

eidx = 20  # Change this index to visualize different error cases

# Which gallery sample to compare against the query:
#   True  -> the true match
#   False -> the top-ranked (wrong) prediction
# Previously both assignments ran back to back, so the true-match line was a
# dead store that was always overwritten.
SHOW_TRUE_MATCH = False

error_case = error_cases[eidx]
q_idx = error_case['query_idx']
g_idx = error_case['gallery_true_idx'] if SHOW_TRUE_MATCH else error_case['gallery_predicted_idx']

query_img = query_imgs[q_idx]
gallery_img = gallery_imgs[g_idx]
query_templ = query_templates[q_idx]
gallery_templ = gallery_templates[g_idx]

final_score, pairs, scores_i, scores_f, sorted_indices, n_pair = dmdu.match_with_details(query_templ, gallery_templ)

# Drop dimension 0 from every returned tensor when it has size 1
# (presumably a batch dimension; confirm against dmd_utils.match_with_details).
pairs, scores_i, scores_f = pairs.squeeze(0), scores_i.squeeze(0), scores_f.squeeze(0)
sorted_indices, n_pair = sorted_indices.squeeze(0), n_pair.squeeze(0)

In [None]:
# Three views of the same comparison, each a list of (score, idx_q, idx_g):
#   plot_pairs_i: every candidate pair with its scores_i value
#   plot_pairs_f: every candidate pair with its scores_f value
#   plot_pairs_s: the first n_pair pairs in sorted_indices order, with scores_f
plot_pairs_i, plot_pairs_f = [], []
plot_pairs_s = []  # The true selected pairs
for pair, score_i, score_f in zip(pairs, scores_i, scores_f):
    plot_pairs_i.append((score_i, *pair))
    plot_pairs_f.append((score_f, *pair))

for i in sorted_indices[:n_pair]:
    pair = pairs[i]
    score_f = scores_f[i]
    plot_pairs_s.append((score_f, *pair))

# One figure per pair set, in the order i, f, selected
for pair_set in (plot_pairs_i, plot_pairs_f, plot_pairs_s):
    plot_minutiae_matches(
        imagem_query=query_img,
        imagem_gallery=gallery_img,
        template_query=query_templ,
        template_gallery=gallery_templ,
        pairs=pair_set,
    )

# The title goes on the current figure, i.e. the last one (selected pairs)
plt.suptitle(f'Query idx: {q_idx}, Gallery idx: {g_idx}, Final Score: {float(final_score):.2f}', fontsize=16)