# Data Visualization Notebook

### Import Required Libraries and Packages

In [2]:
# -----------------------------
# 📦 Imports
# -----------------------------
import numpy as np
import matplotlib.pyplot as plt
import random

### Visualize Patches from Dataset

In [9]:
# -----------------------------
# 🔍 Patch Visualization Function
# -----------------------------
def visualize_labeled_patches(npz_file, samples_per_class=5):
    """
    Display random patch pairs for each label: positive (1), negative (0),
    hard negative (-1).

    The figure has three rows (one per label) and ``samples_per_class``
    columns. Each cell shows the two channels of one patch side by side
    (channel 0 on the left, channel 1 on the right). If a label has fewer
    than ``samples_per_class`` patches, every available patch of that label
    is shown and the remaining cells in its row are left blank.

    Args:
        npz_file (str): Path to .npz file with 'patches' and 'labels'.
            Each entry of 'patches' is indexed as ``patch[:, :, 0]`` and
            ``patch[:, :, 1]``, so it must be 3-D with at least two channels
            on the last axis (e.g. an array of shape (N, 40, 40, 2)).
        samples_per_class (int): Maximum number of patch pairs to show per
            class. Must be at least 1.

    Raises:
        ValueError: If ``samples_per_class`` is smaller than 1, or if the
            file contains no patch labelled 1, 0 or -1.
    """
    if samples_per_class < 1:
        raise ValueError(
            f"samples_per_class must be at least 1, got {samples_per_class}"
        )

    # np.load returns a lazily-read NpzFile that keeps the archive open;
    # the context manager closes it once both arrays are materialised.
    with np.load(npz_file) as data:
        patches = data['patches']
        labels = data['labels']

    # (label value, title text, title colour) in display order, top to bottom
    class_specs = [
        (1, "Positive", "green"),
        (0, "Negative", "red"),
        (-1, "Hard Negative", "orange"),
    ]

    # Draw up to samples_per_class distinct indices for every class
    all_samples = []
    for label_value, label_text, color in class_specs:
        candidates = np.where(labels == label_value)[0]
        # Cap at the population size: sampling without replacement cannot
        # return more items than exist.
        count = min(samples_per_class, len(candidates))
        if count:
            chosen = np.random.choice(candidates, count, replace=False)
        else:
            chosen = candidates  # empty: nothing to draw for this class
        all_samples.append((chosen, label_text, color))

    if not any(len(indices) for indices, _, _ in all_samples):
        raise ValueError(
            f"No patches labelled 1, 0 or -1 found in {npz_file!r}"
        )

    # Plot grid: 3 rows (one per label), N columns.
    # squeeze=False keeps `axes` two-dimensional even for a single column.
    fig, axes = plt.subplots(
        len(class_specs),
        samples_per_class,
        figsize=(samples_per_class * 2, 6),
        squeeze=False,
    )
    fig.suptitle("Patch Pairs by Label", fontsize=16)

    for row_axes, (indices, label_text, color) in zip(axes, all_samples):
        for col, ax in enumerate(row_axes):
            if col < len(indices):
                patch_pair = patches[indices[col]]
                anchor = patch_pair[:, :, 0]
                pair = patch_pair[:, :, 1]

                # Concatenate side-by-side
                combined = np.concatenate([anchor, pair], axis=1)

                ax.imshow(combined, cmap='gray')
                ax.set_title(label_text, color=color, fontsize=10)
            # Hide ticks and frame on every cell, including unused ones
            ax.axis('off')

    plt.tight_layout()
    plt.subplots_adjust(top=0.88)  # Make room for suptitle
    plt.show()

In [None]:
# -----------------------------
# ▶️ Call the function
# -----------------------------
# Render 5 random patch pairs per label from the training set;
# point npz_file at a different archive if yours lives elsewhere.
visualize_labeled_patches(
    npz_file="../data/dataset/train_dataset.npz",
    samples_per_class=5,
)