# Footwear Impression Matching - Visualization and Analysis

This notebook provides visualization and analysis tools for the footwear impression matching system. It allows you to:
1. Explore the dataset
2. Visualize augmentations
3. Analyze model results
4. Examine failure cases
5. Visualize the learned embedding space

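Before running this notebook, process the dataset so that `../data/processed/train_pairs.csv` and `../data/processed/val_pairs.csv` exist. Also train a model; the analysis sections load `../results/checkpoints/best_model.pth`. Adjust these paths in the cells below if your layout differs.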

In [None]:
# Import libraries
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import cv2
from PIL import Image
from tqdm import tqdm

# Add parent directory to path for imports
sys.path.append('..')

# Import project modules
from data.dataloader import get_transforms, ShoeImpressDataset
from data.augmentation import FootwearAugmenter
from models.network import FootwearMatchingNetwork
from utils.common import calculate_metrics, plot_precision_recall_curve, plot_roc_curve

## 1. Dataset Exploration

Let's start by exploring the dataset structure and visualizing some examples.

In [None]:
# Set paths (modify as needed)
data_dir = '../data/processed'
train_csv, val_csv = (os.path.join(data_dir, name) for name in ('train_pairs.csv', 'val_pairs.csv'))

# Load pair information
train_pairs, val_pairs = pd.read_csv(train_csv), pd.read_csv(val_csv)

print(f"Training pairs: {len(train_pairs)}")
print(f"Validation pairs: {len(val_pairs)}")

# Class distribution (label 1 = matching pair, label 0 = non-matching pair)
train_labels, val_labels = train_pairs['label'], val_pairs['label']
train_pos, train_neg = train_labels.eq(1).sum(), train_labels.eq(0).sum()
val_pos, val_neg = val_labels.eq(1).sum(), val_labels.eq(0).sum()

print(f"\nClass distribution:")
print(f"Training: {train_pos} positive ({train_pos/len(train_pairs)*100:.1f}%), {train_neg} negative ({train_neg/len(train_pairs)*100:.1f}%)")
print(f"Validation: {val_pos} positive ({val_pos/len(val_pairs)*100:.1f}%), {val_neg} negative ({val_neg/len(val_pairs)*100:.1f}%)")

In [None]:
# Visualize some examples
def show_pair(track_path, ref_path, label):
    """Display a track impression and a reference impression side by side.

    Args:
        track_path: Path (str or path-like) to the track (crime-scene) impression image.
        ref_path: Path (str or path-like) to the reference impression image.
        label: 1 titles the figure "Match Pair"; any other value titles it "Non-Match Pair".

    Raises:
        FileNotFoundError: if OpenCV cannot read either image.
    """
    rgb_images = []
    for path in (track_path, ref_path):
        # cv2.imread signals a missing/unreadable file by returning None instead of raising;
        # str() lets callers pass pathlib.Path objects as well.
        bgr = cv2.imread(str(path))
        if bgr is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV loads BGR; matplotlib expects RGB
        rgb_images.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    track_img, ref_img = rgb_images

    fig, ax = plt.subplots(1, 2, figsize=(12, 6))

    ax[0].imshow(track_img)
    ax[0].set_title("Track (Crime Scene) Impression")
    ax[0].axis('off')

    ax[1].imshow(ref_img)
    ax[1].set_title("Reference Impression")
    ax[1].axis('off')

    match_status = "Match" if label == 1 else "Non-Match"
    plt.suptitle(f"{match_status} Pair", fontsize=16)
    plt.tight_layout()
    plt.show()

# Show a few random positive and negative pairs.
# A fixed random_state makes Restart & Run All show the same examples every time,
# and the sample size is capped so the cell still works when a class has fewer pairs.
N_EXAMPLE_PAIRS = 3
EXAMPLE_SEED = 42


def _sample_examples(pairs, label_value, n=N_EXAMPLE_PAIRS, seed=EXAMPLE_SEED):
    """Return up to `n` reproducibly sampled rows of `pairs` whose 'label' equals `label_value`."""
    subset = pairs[pairs['label'] == label_value]
    return subset.sample(n=min(n, len(subset)), random_state=seed)


positive_pairs = _sample_examples(train_pairs, 1)
negative_pairs = _sample_examples(train_pairs, 0)

for header, examples in (("Positive Pairs (Same footwear):\n", positive_pairs),
                         ("Negative Pairs (Different footwear):\n", negative_pairs)):
    print(header)
    for _, row in examples.iterrows():
        print(f"Track ID: {row['track_id']}, Reference ID: {row['ref_id']}")
        show_pair(row['track_path'], row['ref_path'], row['label'])

## 2. Data Augmentation Visualization

Let's visualize the different augmentation techniques used in the system.

In [None]:
# Create augmenter
augmenter = FootwearAugmenter()

# Sample a track image (the first matching pair in the training set)
sample_row = train_pairs[train_pairs['label'] == 1].iloc[0]
track_path = sample_row['track_path']
track_img = cv2.imread(track_path)
# cv2.imread returns None instead of raising when the file is missing or unreadable
if track_img is None:
    raise FileNotFoundError(f"Could not read image: {track_path}")
track_img = cv2.cvtColor(track_img, cv2.COLOR_BGR2RGB)

# Generate augmentations; materialize in case the augmenter yields lazily
augmented_batch = list(augmenter.create_augmentation_batch(track_img, num_variants=8))

# Visualize the batch on a 3-column grid whose row count follows the batch size,
# so changing num_variants cannot overflow a fixed 3x3 layout.
n_cols = 3
n_rows = max(1, -(-len(augmented_batch) // n_cols))  # ceiling division
plt.figure(figsize=(20, 4 * n_rows))

for i, (img, aug_types) in enumerate(augmented_batch):
    plt.subplot(n_rows, n_cols, i + 1)
    plt.imshow(img)
    plt.title(", ".join(aug_types))
    plt.axis('off')

plt.tight_layout()
plt.show()

## 3. Model Analysis

Now, let's load a trained model and analyze its performance.

In [None]:
# Load model
def load_model(model_path, backbone='resnet50', feature_dim=256, device='cuda'):
    """Build a FootwearMatchingNetwork and load trained weights from a checkpoint.

    Args:
        model_path: Path to the checkpoint. Either a raw state dict or a dict that
            stores the weights under 'model_state_dict' is accepted.
        backbone: Backbone name forwarded to FootwearMatchingNetwork.
        feature_dim: Feature dimension forwarded to FootwearMatchingNetwork.
        device: Requested device (string or torch.device). A CUDA request falls back
            to CPU when CUDA is unavailable; other devices are used as given.

    Returns:
        (model, device): the model in eval mode on `device`, and the torch.device used.
    """
    # Only a CUDA request needs a fallback; previously any device (e.g. 'mps')
    # was silently replaced by CPU whenever CUDA was missing.
    device = torch.device(device)
    if device.type == 'cuda' and not torch.cuda.is_available():
        device = torch.device('cpu')

    # pretrained=False: the backbone weights come from the checkpoint below
    model = FootwearMatchingNetwork(
        backbone=backbone,
        pretrained=False,
        feature_dim=feature_dim
    )

    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)

    # Training checkpoints may bundle the weights with optimizer/epoch info
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint

    # Weights saved from nn.DataParallel / DistributedDataParallel carry a 'module.' prefix
    prefix = 'module.'
    if isinstance(state_dict, dict) and state_dict and all(
            isinstance(key, str) and key.startswith(prefix) for key in state_dict):
        state_dict = {key[len(prefix):]: value for key, value in state_dict.items()}

    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    return model, device

# Checkpoint to analyze (modify as needed)
model_path = '../results/checkpoints/best_model.pth'

# Attempt the load; a wrong path is reported instead of stopping the notebook
try:
    model, device = load_model(model_path)
except Exception as err:
    print(f"Could not load model: {str(err)}")
    print("Please set the correct model path.")
else:
    print(f"Model loaded successfully! Using device: {device}")

In [None]:
# Evaluate on validation set
def evaluate_model(model, val_csv, device, batch_size=32, img_size=512, num_workers=4):
    """Run the model over every validation pair and compute classification metrics.

    Args:
        model: Matching network; called as model(track_imgs, ref_imgs) and expected to
            return a 4-tuple whose first two items are logits and similarities.
        val_csv: Path to the validation pairs CSV consumed by ShoeImpressDataset.
        device: torch.device or device string to run inference on.
        batch_size: DataLoader batch size.
        img_size: Image size passed to get_transforms.
        num_workers: DataLoader worker processes.

    Returns:
        (metrics, logits, similarities, labels): the result of calculate_metrics and
        the concatenated numpy arrays, in dataset order (the loader does not shuffle).

    Raises:
        ValueError: if the validation dataset is empty.
    """
    device = torch.device(device)

    # Deterministic validation transforms, no online augmentation
    transform = get_transforms(mode='val', img_size=img_size)
    val_dataset = ShoeImpressDataset(
        val_csv,
        transform=transform,
        triplet_mode=False,
        online_augment=False
    )
    if len(val_dataset) == 0:
        raise ValueError(f"No validation pairs found in {val_csv}")

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,  # keep CSV order so results can be mapped back to rows
        num_workers=num_workers,
        # Pinned host memory only speeds up host-to-GPU copies
        pin_memory=(device.type == 'cuda')
    )

    print(f"Evaluating on {len(val_dataset)} validation pairs")

    all_logits = []
    all_similarities = []
    all_labels = []

    # Disable dropout / use running BatchNorm statistics regardless of the caller's mode
    model.eval()
    with torch.no_grad():
        for track_imgs, ref_imgs, labels, _ in tqdm(val_loader):
            track_imgs = track_imgs.to(device)
            ref_imgs = ref_imgs.to(device)

            logits, similarities, _, _ = model(track_imgs, ref_imgs)

            # One logit per sample. A bare .squeeze() turns a size-1 final batch
            # into a 0-d tensor, which torch.cat refuses to concatenate.
            batch_logits = logits.cpu()
            all_logits.append(batch_logits.reshape(batch_logits.shape[0], -1).squeeze(-1))
            all_similarities.append(similarities.cpu())
            all_labels.append(labels)

    all_logits = torch.cat(all_logits).numpy()
    all_similarities = torch.cat(all_similarities).numpy()
    all_labels = torch.cat(all_labels).numpy()

    metrics = calculate_metrics(all_logits, all_similarities, all_labels)

    return metrics, all_logits, all_similarities, all_labels

# Evaluate the loaded model and draw its PR / ROC curves; errors are reported, not raised
try:
    metrics, logits, similarities, labels = evaluate_model(model, val_csv, device)

    # Headline numbers
    print("\nEvaluation Results:")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Average Precision: {metrics['ap']:.4f}")
    print(f"ROC AUC: {metrics['roc_auc']:.4f}")
    print(f"AP from Similarity: {metrics['sim_ap']:.4f}")

    fig, (pr_ax, roc_ax) = plt.subplots(1, 2, figsize=(15, 6))

    # Left panel: precision vs. recall
    pr_ax.plot(metrics['recall'], metrics['precision'], lw=2, marker='.', markersize=3)
    pr_ax.set_xlabel('Recall')
    pr_ax.set_ylabel('Precision')
    pr_ax.set_title(f'Precision-Recall Curve (AP = {metrics["ap"]:.4f})')
    pr_ax.grid(True)

    # Right panel: ROC with the chance-level diagonal for reference
    roc_ax.plot(metrics['fpr'], metrics['tpr'], lw=2, marker='.', markersize=3)
    roc_ax.plot([0, 1], [0, 1], 'k--', lw=1.5)
    roc_ax.set_xlabel('False Positive Rate')
    roc_ax.set_ylabel('True Positive Rate')
    roc_ax.set_title(f'ROC Curve (AUC = {metrics["roc_auc"]:.4f})')
    roc_ax.grid(True)

    fig.tight_layout()
    plt.show()

except Exception as exc:
    print(f"Evaluation failed: {str(exc)}")

## 4. Analysis of Failure Cases

Let's analyze some failure cases to understand the model's limitations.

In [None]:
# Analyze failure cases
def analyze_failures(logits, similarities, labels, val_csv, n_samples=5, threshold=0.5):
    """Report, and display the worst of, the validation pairs the classifier gets wrong.

    Args:
        logits: One raw classifier logit per validation pair, in CSV row order.
        similarities: Similarity scores from the model (unused here; kept for call
            compatibility).
        labels: Ground-truth labels (1 = match, 0 = non-match) aligned with `logits`.
        val_csv: Path (or readable buffer) of the validation pairs CSV the predictions
            were computed from. Rows are looked up by position.
        n_samples: How many false positives / false negatives to display.
        threshold: Match probability above which a pair is predicted to be a match.

    Returns:
        (false_positives, false_negatives): lists of (row_index, match_probability).
        False positives are sorted by descending probability and false negatives by
        ascending probability, so the most confident mistakes come first. The printed
        "Confidence" is this match probability for both lists.

    Raises:
        ValueError: if the CSV row count, logits and labels differ in length.
    """
    val_pairs = pd.read_csv(val_csv)

    logits = np.asarray(logits, dtype=float).reshape(-1)
    labels = np.asarray(labels).reshape(-1)
    # Rows are fetched with .iloc[idx]; a length mismatch would silently show the wrong pairs
    if not (len(val_pairs) == len(logits) == len(labels)):
        raise ValueError(
            f"Length mismatch: {len(val_pairs)} CSV rows, {len(logits)} logits, {len(labels)} labels"
        )

    # Numerically stable sigmoid: exp(-log(1 + exp(-x))) never overflows,
    # unlike 1 / (1 + exp(-x)) for large negative logits.
    probas = np.exp(-np.logaddexp(0.0, -logits))
    predictions = (probas > threshold).astype(int)

    n_total = len(labels)
    n_failures = int(np.count_nonzero(predictions != labels))
    failure_rate = n_failures / n_total * 100 if n_total else 0.0
    print(f"Found {n_failures} failure cases out of {n_total} samples ({failure_rate:.2f}%)")

    false_positives = [(int(i), float(probas[i]))
                       for i in np.flatnonzero((predictions == 1) & (labels == 0))]
    false_negatives = [(int(i), float(probas[i]))
                       for i in np.flatnonzero((predictions == 0) & (labels == 1))]

    print(f"False positives: {len(false_positives)}")
    print(f"False negatives: {len(false_negatives)}")

    # Most confident mistakes first
    false_positives.sort(key=lambda x: x[1], reverse=True)
    false_negatives.sort(key=lambda x: x[1])

    def _show_cases(header, cases):
        # Print and plot the first n_samples cases under `header` (skipped when empty)
        if not cases:
            return
        print(header)
        for rank, (idx, conf) in enumerate(cases[:n_samples], start=1):
            pair = val_pairs.iloc[idx]
            print(f"Pair {rank}: Track ID {pair['track_id']}, Reference ID {pair['ref_id']}, Confidence: {conf:.4f}")
            show_pair(pair['track_path'], pair['ref_path'], pair['label'])

    _show_cases("\nTop False Positives (Predicted Match, Actually Different):", false_positives)
    _show_cases("\nTop False Negatives (Predicted Different, Actually Match):", false_negatives)

    return false_positives, false_negatives

# Inspect the mistakes; any error is printed rather than raised so later cells still run
try:
    analyze_failures(logits, similarities, labels, val_csv)
except Exception as err:
    print(f"Analysis failed: {str(err)}")

## 5. Feature Visualization

Let's visualize the learned embedding space to see how well the model separates matching from non-matching impressions.

In [None]:
# Extract embeddings
def _meta_values(meta, key):
    """Return the per-sample values stored under `key` in one batch of metadata.

    With the DataLoader's default collate function, a batch of per-sample dicts
    arrives as ONE dict mapping each key to a list (strings) or a tensor (numbers);
    a plain sequence of per-sample dicts is accepted as well.
    """
    if isinstance(meta, dict):
        values = meta[key]
        return values.tolist() if torch.is_tensor(values) else list(values)
    return [m[key] for m in meta]


def extract_embeddings(model, val_csv, device, batch_size=32, img_size=512, num_workers=4):
    """Compute globally pooled track/reference embeddings for every validation pair.

    Args:
        model: Matching network; called as model(track, None, mode='track') and
            model(None, ref, mode='ref') to obtain feature maps.
        val_csv: Validation pairs CSV consumed by ShoeImpressDataset.
        device: torch.device or device string to run inference on.
        batch_size: DataLoader batch size.
        img_size: Image size passed to get_transforms.
        num_workers: DataLoader worker processes.

    Returns:
        (track_embeddings, ref_embeddings, labels, track_ids, ref_ids) in dataset order:
        two numpy embedding arrays, a numpy label array, and two lists of ids taken
        from each sample's metadata.

    Raises:
        ValueError: if the validation dataset is empty.
    """
    device = torch.device(device)

    transform = get_transforms(mode='val', img_size=img_size)
    val_dataset = ShoeImpressDataset(
        val_csv,
        transform=transform,
        triplet_mode=False,
        online_augment=False
    )
    if len(val_dataset) == 0:
        raise ValueError(f"No validation pairs found in {val_csv}")

    # No collate_fn is given, so the default collate function batches the metadata
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=(device.type == 'cuda')
    )

    print(f"Extracting embeddings for {len(val_dataset)} samples")

    track_embeddings = []
    ref_embeddings = []
    all_labels = []
    track_ids = []
    ref_ids = []

    model.eval()
    with torch.no_grad():
        for track_imgs, ref_imgs, labels, meta in tqdm(val_loader):
            track_imgs = track_imgs.to(device)
            ref_imgs = ref_imgs.to(device)

            track_features = model(track_imgs, None, mode='track')
            ref_features = model(None, ref_imgs, mode='ref')

            # Global average pooling to one vector per sample.
            # Assumes 4-D (B, C, H, W) feature maps — TODO confirm against the network.
            # F is torch.nn.functional, imported in the top import cell.
            track_emb = F.adaptive_avg_pool2d(track_features, 1).squeeze(-1).squeeze(-1)
            ref_emb = F.adaptive_avg_pool2d(ref_features, 1).squeeze(-1).squeeze(-1)

            track_embeddings.append(track_emb.cpu())
            ref_embeddings.append(ref_emb.cpu())
            all_labels.append(labels)

            # Default collate yields a dict of lists, not a list of dicts
            track_ids.extend(_meta_values(meta, 'track_id'))
            ref_ids.extend(_meta_values(meta, 'ref_id'))

    track_embeddings = torch.cat(track_embeddings).numpy()
    ref_embeddings = torch.cat(ref_embeddings).numpy()
    all_labels = torch.cat(all_labels).numpy()

    return track_embeddings, ref_embeddings, all_labels, track_ids, ref_ids

# Embed every validation pair; report (rather than raise) any error so the notebook continues
try:
    track_embs, ref_embs, emb_labels, track_ids, ref_ids = extract_embeddings(
        model, val_csv, device
    )
    print(f"Extracted {len(track_embs)} embeddings with {track_embs.shape[1]} dimensions")
except Exception as exc:
    print(f"Embedding extraction failed: {str(exc)}")

In [None]:
# Visualize embeddings with dimensionality reduction
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

def visualize_embeddings(track_embs, ref_embs, labels, method='tsne', n_components=2):
    """Jointly project track and reference embeddings and scatter-plot the first two axes.

    Both sets are reduced together so they share one coordinate system. Row i of
    `track_embs`, `ref_embs` and `labels` is assumed to describe the same pair;
    matching pairs (label 1) are joined by a line.

    Args:
        track_embs: 2-D array (one row per pair) of track-impression embeddings.
        ref_embs: 2-D array (one row per pair) of reference-impression embeddings.
        labels: Per-pair labels (1 = match, 0 = non-match).
        method: 'tsne' or 'pca' (case-insensitive).
        n_components: Output dimensionality of the reduction; the first two
            components are plotted.

    Raises:
        ValueError: for an unknown method, n_components < 2, or mismatched lengths.
    """
    method = method.lower()
    if method not in ('tsne', 'pca'):
        raise ValueError(f"Unknown method {method!r}; expected 'tsne' or 'pca'")
    if n_components < 2:
        raise ValueError("n_components must be at least 2 to draw a 2-D scatter plot")

    labels = np.asarray(labels).reshape(-1)
    n_pairs = len(track_embs)
    if len(ref_embs) != n_pairs or len(labels) != n_pairs:
        raise ValueError("track_embs, ref_embs and labels must have the same length")

    all_embs = np.vstack([track_embs, ref_embs])

    if method == 'tsne':
        # scikit-learn requires perplexity < n_samples; cap its default of 30 so small sets work
        perplexity = min(30.0, len(all_embs) - 1)
        reducer = TSNE(n_components=n_components, perplexity=perplexity, random_state=42)
    else:
        reducer = PCA(n_components=n_components, random_state=42)

    reduced_embs = reducer.fit_transform(all_embs)

    # Split back into track and reference halves
    track_reduced = reduced_embs[:n_pairs]
    ref_reduced = reduced_embs[n_pairs:]

    plt.figure(figsize=(12, 10))

    pos_idx = (labels == 1)
    neg_idx = (labels == 0)

    # Track points (circles)
    plt.scatter(track_reduced[pos_idx, 0], track_reduced[pos_idx, 1], c='blue', marker='o', s=50, alpha=0.7, label='Track (Match)')
    plt.scatter(track_reduced[neg_idx, 0], track_reduced[neg_idx, 1], c='red', marker='o', s=50, alpha=0.7, label='Track (Non-Match)')

    # Reference points (squares)
    plt.scatter(ref_reduced[pos_idx, 0], ref_reduced[pos_idx, 1], c='cyan', marker='s', s=50, alpha=0.7, label='Reference (Match)')
    plt.scatter(ref_reduced[neg_idx, 0], ref_reduced[neg_idx, 1], c='magenta', marker='s', s=50, alpha=0.7, label='Reference (Non-Match)')

    # Join matching pairs in a single call: with 2xK arrays matplotlib draws one
    # segment per column instead of K separate plot() calls.
    if pos_idx.any():
        plt.plot(np.vstack([track_reduced[pos_idx, 0], ref_reduced[pos_idx, 0]]),
                 np.vstack([track_reduced[pos_idx, 1], ref_reduced[pos_idx, 1]]),
                 'k-', alpha=0.3)

    plt.title(f'{method.upper()} Visualization of Footwear Impression Embeddings', fontsize=16)
    plt.xlabel(f'{method.upper()}-1', fontsize=14)
    plt.ylabel(f'{method.upper()}-2', fontsize=14)
    plt.legend(fontsize=12)
    plt.grid(True, alpha=0.3)
    plt.show()

# Plot the embedding space with each reducer in turn; a failure stops the loop and is reported
try:
    for banner, reducer_name in (("\nT-SNE Visualization:", 'tsne'),
                                 ("\nPCA Visualization:", 'pca')):
        print(banner)
        visualize_embeddings(track_embs, ref_embs, emb_labels, method=reducer_name)
except Exception as e:
    print(f"Visualization failed: {str(e)}")

## 6. Conclusion

In this notebook, we have:
1. Explored the footwear impression dataset
2. Visualized augmentation techniques
3. Analyzed model performance
4. Examined failure cases
5. Visualized the embedding space

These insights can help us understand how the model works, what its limitations are, and how we might improve it in the future.