# C1 W1 Group 8

In [None]:
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image

from src.data import GT_QSD1_W1_LIST
from src.paths import BBDD_PATH, QSD1_W1_PATH

## Task 1 - Create Museum and query image descriptors (BBDD & QSD1)

The functions in this section are required to take a PIL.Image object as input and return a 1D descriptor in the form of a NumPy array. Inside the function, you have the freedom to implement any processing or transformation steps as long as the input and output types are respected. Specifically:

    Input: A PIL.Image object, which can be manipulated or processed as needed.
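    Output: A 1D descriptor, represented as a NumPy array, which could be a histogram, feature vector, or any other type of descriptor derived from the input image. Every image must produce a descriptor of the same length K, because the descriptors of all images are stacked into a single 2D array.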

In [None]:
def get_grayscale_histogram_descriptor(image: "Image.Image", bins: int = 256) -> np.ndarray:
    """Return the grayscale intensity histogram of *image* as a 1D descriptor.

    Args:
        image: PIL image; it is converted to 8-bit grayscale ("L" mode) first.
        bins: Number of equal-width bins spanning the intensity range [0, 256).

    Returns:
        Array of length ``bins`` holding the pixel count of each bin.
    """
    pixels = np.asarray(image.convert('L'))
    # The range must cover every 8-bit intensity independently of ``bins``: a range of
    # (0, bins - 1) would silently drop all pixels brighter than ``bins - 1`` whenever
    # bins < 256. With the default 256 bins, each intensity value gets its own bin.
    histogram, _ = np.histogram(pixels, bins=bins, range=(0, 256))
    return histogram

In [None]:
# Replace with the function you want to use to generate the descriptors.
# Any candidate must take a PIL image and return a 1D NumPy array; the same
# function is applied to both the database and the query images.
get_descriptors_func = get_grayscale_histogram_descriptor

In [None]:
# Database image descriptors
# Paths are sorted because Path.glob yields files in arbitrary, filesystem-dependent
# order, while the retrieval/visualization code maps row j of this array to the file
# bbdd_{j:05}.jpg. The zero-padded names make lexicographic order equal numeric order.
database_image_descriptors = []
for database_image_path in sorted(BBDD_PATH.glob("*.jpg")):
    # The context manager closes the file once the descriptor has been computed.
    with Image.open(database_image_path) as database_image_PIL:
        descriptors = get_descriptors_func(database_image_PIL)
    database_image_descriptors.append(descriptors)
database_image_descriptors = np.array(database_image_descriptors)

In [None]:
# Query image descriptors
# Paths are sorted because Path.glob yields files in arbitrary, filesystem-dependent
# order, while later code pairs row i of this array with the file {i:05}.jpg and with
# GT_QSD1_W1_LIST[i]. The zero-padded names make lexicographic order equal numeric order.
query_image_descriptors = []
for query_image_path in sorted(QSD1_W1_PATH.glob("*.jpg")):
    # The context manager closes the file once the descriptor has been computed.
    with Image.open(query_image_path) as query_image_PIL:
        descriptors = get_descriptors_func(query_image_PIL)
    query_image_descriptors.append(descriptors)
query_image_descriptors = np.array(query_image_descriptors)

In [None]:
# Overlay every query descriptor on a single set of axes so their shapes can be compared
fig, ax = plt.subplots(figsize=(10, 6))
bins = np.arange(query_image_descriptors.shape[1])  # x-axis: one position per descriptor entry
for position, hist in enumerate(query_image_descriptors, start=1):
    ax.plot(bins, hist, label=f'Descriptor {position}')

# Title, axis labels, a compact legend and a background grid
ax.set_title('Descriptors')
ax.set_xlabel('Bins')
ax.set_ylabel('Frequency')
ax.legend(fontsize=5)
ax.grid(True)

plt.show()

## Task 2 - Implement / compute similarity measures to compare images

In this section, the functions should implement various distance measures between query descriptors and database descriptors. The input for both the query and database will be a 2D NumPy array, where each row represents the descriptor of one image. Specifically:

    Query descriptors will have shape (N, K), where N is the number of query images, and K is the length of each descriptor.
    Database descriptors will have shape (M, K), where M is the number of database images, and K is the length of each descriptor.

The output should be a 2D array of shape (N, M), where entry (i, j) is the distance between the i-th query descriptor and the j-th database descriptor (i = 0, …, N−1; j = 0, …, M−1).

In [None]:
def compute_mse(query_descriptors: np.ndarray, database_descriptors: np.ndarray) -> np.ndarray:
    """Return the pairwise mean squared error between query and database descriptors.

    Args:
        query_descriptors: Array of shape (N, K), one descriptor per query image.
        database_descriptors: Array of shape (M, K), one descriptor per database image.

    Returns:
        Float array of shape (N, M) whose entry (i, j) is the mean, over the K
        descriptor entries, of the squared difference between query i and database j.
    """
    # Work in float64: with unsigned integer descriptors (e.g. uint8) the subtraction
    # and squaring would wrap around. For the int64 histograms used above the result
    # is unchanged, because np.mean already accumulates integers in float64.
    query = np.asarray(query_descriptors, dtype=np.float64)
    database = np.asarray(database_descriptors, dtype=np.float64)

    # Broadcast to (N, M, K) to get every pairwise difference at once
    differences = query[:, np.newaxis, :] - database[np.newaxis, :, :]

    # Mean over the descriptor axis (K) gives the MSE for each (query, database) pair
    return np.square(differences).mean(axis=2)

In [None]:
# Replace with the function you want to use to compute distances.
# It receives the (N, K) query and (M, K) database descriptor arrays and must return
# an (N, M) distance matrix; the retrieval step ranks smaller values as closer matches.
compute_distance_func = compute_mse

In [None]:
# Distance matrix of shape (N, M): entry (i, j) compares query i with database image j
query_distances = compute_distance_func(query_image_descriptors, database_image_descriptors)

In [None]:
plt.figure(figsize=(10, 8))

# Min-max normalise each row (one query) independently. This keeps the ordering within
# a row identical to the ranking used for retrieval, so the smallest value in row i is
# the best match for query i. Normalising per column instead would rescale every
# database image differently and could reorder the distances within a row.
row_min = query_distances.min(axis=1, keepdims=True)
row_range = query_distances.max(axis=1, keepdims=True) - row_min
# Avoid division by zero when every distance in a row is identical
row_range[row_range == 0] = 1
normalized_query_distances = (query_distances - row_min) / row_range

# Use a heatmap to display the distance matrix
plt.imshow(normalized_query_distances, aspect='auto')

# Add a color bar to show the scale of the (normalised) distances
plt.colorbar(label='Normalized distance (per query)')

# Add labels for clarity
plt.xlabel('Database Images')
plt.ylabel('Query Images')
plt.title('Distance Matrix Between Query and Database Descriptors')

# Show the plot
plt.show()

## Task 3 - Implement retrieval system (retrieve top K results)

In [None]:
def get_topk_distances(distances: np.ndarray, k: int = 1) -> tuple[list[list], list[list]]:
    """Retrieve, for every query, the k database entries with the smallest distance.

    Args:
        distances: Array-like of shape (N, M); entry (i, j) is the distance between
            query i and database item j.
        k: Number of results per query. If k exceeds M, all M entries are returned.

    Returns:
        (indices, scores) as nested lists of shape (N, min(k, M)). indices[i] holds the
        database indices for query i in order of increasing distance, and scores[i]
        holds the matching distances.

    Raises:
        ValueError: If k is negative.
    """
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    distances = np.asarray(distances)

    # A stable sort breaks ties between equal distances by the lower database index, so
    # the ranking is reproducible; the default quicksort gives no guarantee for ties.
    indices = np.argsort(distances, axis=1, kind="stable")[:, :k]

    # Gather the distances that correspond to the selected indices
    scores = np.take_along_axis(distances, indices, axis=1)

    return indices.tolist(), scores.tolist()

In [None]:
# Select number of results to be retrieved per query.
# Also sets the number of prediction columns in the visualization below.
k = 1

In [None]:
# indices[i] / scores[i]: the k closest database indices for query i and their distances
indices, scores = get_topk_distances(query_distances, k=k)

In [None]:
# Metrics



In [None]:
# Visualization
import matplotlib.image as mpimg


def _show_image(ax, image_path, title):
    """Draw the image at ``image_path`` on ``ax`` with ``title`` and hidden axes."""
    # imshow converts the PIL image to an array immediately, so the file can be
    # closed as soon as the context manager exits.
    with Image.open(image_path) as image:
        ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')


# Number of rows: one per query image listed in the ground truth
n = len(GT_QSD1_W1_LIST)

# Create a figure with n rows and k+2 columns: query, ground truth, and k predictions.
# squeeze=False keeps ``axs`` 2D even when n == 1, so ``axs[i, j]`` always works.
fig, axs = plt.subplots(n, k + 2, figsize=(3 * (k + 2), 3 * n), squeeze=False)

for i in range(n):
    # Column 0: the query image (query files are named by zero-padded index)
    query_image_path = QSD1_W1_PATH / f"{str(i).zfill(5)}.jpg"
    _show_image(axs[i, 0], query_image_path, f"QUERY - {query_image_path.name}")

    # Column 1: the ground-truth database image (only one index in GT)
    gt_idx = GT_QSD1_W1_LIST[i][0]
    database_image_path = BBDD_PATH / f"bbdd_{str(gt_idx).zfill(5)}.jpg"
    _show_image(axs[i, 1], database_image_path, f"BBDD: {database_image_path.name}")

    # Columns 2..k+1: the retrieved database images, closest first
    retrieved = indices[i][:k]
    for j, retrieved_idx in enumerate(retrieved):
        retrieved_image_path = BBDD_PATH / f"bbdd_{str(retrieved_idx).zfill(5)}.jpg"
        _show_image(axs[i, j + 2], retrieved_image_path, f"PRED {j} - {retrieved_image_path.name}")

    # Hide leftover prediction slots when fewer than k results were retrieved
    for ax in axs[i, 2 + len(retrieved):]:
        ax.axis('off')

# Adjust layout
plt.tight_layout()
plt.show()

## Task 4 - Create predictions for blind challenge (QST1)