In [49]:
import torch
from sklearn.metrics import DistanceMetric
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import numpy as np
from scipy.spatial import distance

import sys
sys.path.append('/workspaces/dbm25/task_1_2')
from tabulate import tabulate


from extract_features import extract_features



In [50]:
def top_k_distance_search(image_path, k, feature_model, measure):
    """
    Finds the k most similar images to an input image based on a feature model and distance metric.
    
    Parameters:
    ----------
    image_path : str
        Path to the input image.
    k : int
        Number of similar images to return.
    feature_model : str
        Feature extraction model identifier to use.
    measure : str
        Distance metric name (compatible with sklearn.metrics.DistanceMetric).
    
    Returns:
    -------
    list of dict
        List of k dictionaries containing similar image information:
        - "image_name": image name
        - "file_path": image path
        - "class": image class
        - "distance_score": distance score (lower = more similar)
    """
    
    # Load pre-trained ResNet50 model
    model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
    
    # Extract features from input image
    image_query_surrogate = extract_features(image_path, model)

    # Load feature space from database
    feature_space = torch.load("/workspaces/dbm25/data/extracted_features.pt")
    
    # Get and reshape (linearize) query image features
    image_query_feature = image_query_surrogate[feature_model].reshape(1,-1) 
    
    # Initialize distance metric
    dist = DistanceMetric.get_metric(measure)
    results = []

    # Compare query image against each image in the database
    for image_surrogate in feature_space:
        image_surrogate_feature = image_surrogate[feature_model].reshape(1,-1)
        
        # Calculate distance between query image and current database image
        distance_score = dist.pairwise(image_query_feature, image_surrogate_feature)[0][0]
        image_name = image_surrogate["file_path"].split("/").pop()

        new_result = {
            "image_name": image_name,
            "file_path": image_surrogate["file_path"],
            "class": image_surrogate["class"],
            "distance_score": distance_score
        }
        
        # Maintain a list of only the k most similar images
        if len(results) < k:
            results.append(new_result)
        else:
            # Sort by distance (descending - largest distance first)
            results.sort(key=lambda element: element["distance_score"], reverse=True)
            
            # Replace least similar image if current one is more similar
            if results[0]["distance_score"] > new_result["distance_score"]:
                results.pop(0)
                results.append(new_result)

    return results

In [51]:
def get_k_matching_label(image_path, k, k_neighbors, feature_model, measure):
   
   if k > 3:
      print("K must be less or equal 2!")
      return None

   vote_results = {
      "brain_glioma" : 0,
      "brain_menin" : 0,
      "brain_tumor" : 0
   }

   search_results = top_k_distance_search(image_path=image_path, 
                                k=k_neighbors, 
                                feature_model=feature_model,
                                measure = measure)
   
   for result in search_results:
      vote_results[result["class"]] += 1

   sortedResults = sorted(vote_results.items(), key=lambda x:x[1], reverse=True)


   return sortedResults[:k]


In [None]:
image_path = "/workspaces/dbm25/data/Part2/Part2/brain_tumor/brain_tumor_1006.jpg"
result = get_k_matching_label(image_path=image_path, k=2, k_neighbors=50, feature_model="cm", measure="chebyshev")

print(result)

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


Processing /workspaces/dbm25/data/Part2/Part2/brain_tumor/brain_tumor_1002.jpg
Failed to load image.


[ WARN:0@694.786] global loadsave.cpp:268 findDecoder imread_('/workspaces/dbm25/data/Part2/Part2/brain_tumor/brain_tumor_1002.jpg'): can't open/read file: check file path/integrity


TypeError: 'NoneType' object is not subscriptable