In [None]:
# This file take random "N" number of query and and relevant images will from that random query class, 
# Only we need to give the databse path
# And hyperparameter

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install torch

In [None]:
!pip install Pillow

In [None]:
!pip install torchvision

In [2]:
!pip install mlxtend==0.17.0

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
!pip install pandas

In [None]:
!pip install scikit-image

In [64]:
import os
from PIL import Image
from mlxtend.frequent_patterns import fpgrowth
from torchvision import transforms
import torch
from torchvision.models import vgg16
import glob
import torch.nn as nn
from skimage import measure
import matplotlib.pyplot as plt
import pandas as pd
import random


In [65]:
# Hyper_parameter

no_random_images = 2 # random number of images N

top_no_image_print = 6  # What number of top images to be shown in image, Note it should be less than K

alpha = 0.1 # α is used for balancing the effect of global and local features, Range - {0, 0.01, 0.1, · · · , 100}

k = 5 # Get top K images

minsupp = 2 # minimum support threshold for FPM, The range of minsupp are {0, 1, 2, · · · , 10}

In [66]:

# Define paths  database image folders
database_image_folder = r"/content/drive/MyDrive/Colab Notebooks/database_new"


In [67]:


class GlobalFeature:
    def __init__(self, vgg16):
        self.vgg16 = vgg16

    def extract(self, image):
        with torch.no_grad():
            features = self.vgg16.features(image)

        # Obtain the salient object by performing a mask operation
        features = features.squeeze(0)
        A = features.sum(dim=0)
        threshold = A.mean()
        mask = (A > threshold).float()

        # Retain the largest connected component using the flood fill algorithm
        labels = measure.label(mask.cpu().numpy())
        largest_label = labels.max()
        if largest_label > 0:
            largest_area = 0
            for i in range(1, largest_label + 1):
                area = (labels == i).sum()
                if area > largest_area:
                    largest_area = area
                    largest_component = i

        for i in range(mask.shape[0]):
            for j in range(mask.shape[1]):
                if labels[i, j] != largest_component:
                    mask[i, j] = 0

        mask = mask.unsqueeze(0).unsqueeze(0)
        salient_object = features * mask

        # Extract the global feature fG from the salient object
        fG_max_pooling = nn.functional.adaptive_max_pool2d(salient_object, (1, 1)).view(-1)
        fG_avg_pooling = nn.functional.adaptive_avg_pool2d(salient_object, (1, 1)).view(-1)
        fG = torch.cat((fG_max_pooling, fG_avg_pooling), dim=0)

        return fG

In [68]:


class LocalFeature:
    def __init__(self, vgg16):
        self.vgg16 = vgg16

    def extract(self, image):
        with torch.no_grad():
            features = self.vgg16.features(image)

        # Obtain the salient object by performing a mask operation
        features = features.squeeze(0)
        A = features.sum(dim=0)
        threshold = A.mean()
        mask = (A > threshold).float()
        mask = mask.unsqueeze(0).unsqueeze(0)
        salient_object = features * mask
        salient_object = salient_object.squeeze(0)

        # Convert feature maps and activated positions into transactions and items
        transactions = []
        for i in range(salient_object.shape[0]):
            feature_map = salient_object[i]
            activated_positions = (feature_map > 0).nonzero(as_tuple=True)
            items = [f'({x},{y})' for x, y in zip(*activated_positions)]
            transactions.append(items)

        # Mine frequent patterns using FPM
        
        I = sorted(set(item for transaction in transactions for item in transaction))
        df = pd.DataFrame([[int(item in transaction) for item in I] for transaction in transactions], columns=I)
        frequent_itemsets = fpgrowth(df, min_support=minsupp/len(transactions), use_colnames=True)

        # Extract the local feature fL from the frequent patterns
        patterns = torch.zeros_like(salient_object)
        for itemset in frequent_itemsets['itemsets']:
            for item in itemset:
                x, y = map(int, item.strip('()').split(','))
                patterns[:, x, y] = 1

        fL_max_pooling = nn.functional.adaptive_max_pool2d(patterns, (1, 1)).view(-1)
        fL_avg_pooling = nn.functional.adaptive_avg_pool2d(patterns, (1, 1)).view(-1)
        fL = torch.cat((fL_max_pooling, fL_avg_pooling), dim=0)

        return fL

In [69]:
def similarity_score(feature1, feature2):
    score = torch.dot(feature1, feature2) / (torch.norm(feature1) * torch.norm(feature2))
    return score.item()

In [70]:
def average_precision(retrieved_items, relevant_items):
    rel_count = 0
    precisions = []

    for i, item in enumerate(retrieved_items, start=1):
        if item in relevant_items:
            rel_count += 1
            precision_at_i = rel_count / i
            precisions.append(precision_at_i)

    if precisions:
        avg_precision = sum(precisions) / len(precisions)
    else:
        avg_precision = 0.0

    return avg_precision

In [71]:
def mean_average_precision(ap_scores):
    map_score = sum(ap_scores) / len(ap_scores)
    return map_score

In [72]:
vgg16_model = vgg16(pretrained=True)


In [73]:

# Define image transformation
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.Lambda(lambda image: image.convert('RGB') if image.mode != 'RGB' else image),
    transforms.ToTensor(), 
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


In [74]:

def load_and_transform_images(database_image_folder, transform):
    # Define image extensions
    image_extensions = ['jpg', 'png', 'jpeg']

    database_filenames = []
    database_images = []
    database_file_path = []

    # Iterate over all subdirectories and files in the root directory
    for subdir, dirs, files in os.walk(database_image_folder):
        for filename in files:
            # Check if the file is an image
            if filename.split('.')[-1].lower() in image_extensions:
                # Append the filename to the filenames list
                database_filenames.append(filename)
                # Open and transform the image, then append it to the images list
                file_path = os.path.join(subdir, filename)
                database_file_path.append(file_path)
                with Image.open(os.path.join(subdir, filename)) as image:
                    try:
                        tensor_image = transform(image)
                        database_images.append(tensor_image)
                    except Exception as e:
                        print(f"Error transforming image {filename}: {str(e)}")

    # Convert the list of images into a torch tensor
    try:
        database_images = torch.stack(database_images)
    except Exception as e:
        print(f"Error in stacking images: {str(e)}")

    return database_filenames, database_images, database_file_path

database_filenames, database_images, database_file_path  = load_and_transform_images(database_image_folder, transform)


In [75]:

def get_random_image(number_of_random_images, database_filenames, database_images, database_file_path):
    
    all_average_precision = []

    for i_random in range(number_of_random_images):

        query_image_path = random.choice(database_file_path)
        relevant_image_folder = os.path.dirname(query_image_path)
        query_image = Image.open(query_image_path)
        query_image = transform(query_image)

        def get_image_names_from_folder(root_folder):
            image_extensions = ['jpg', 'png', 'gif', 'jpeg']
            image_names = []

            for ext in image_extensions:
                image_paths = glob.glob(f'{root_folder}/**/*.{ext}', recursive=True)
                for path in image_paths:
                    image_name = os.path.basename(path)
                    image_names.append(image_name)

            return image_names

        relevant_images = get_image_names_from_folder(relevant_image_folder)


        
        global_extractor = GlobalFeature(vgg16_model)
        query_global_feature = global_extractor.extract(query_image)
        local_extractor = LocalFeature(vgg16_model)
        query_local_feature = global_extractor.extract(query_image)

        query_features = query_global_feature + alpha * query_local_feature


        all_similarity_score = []

        for i,j in zip(database_filenames, database_images):
            file_name = i
            image = j
            global_feature = global_extractor.extract(image)
            local_feature = global_extractor.extract(image)
            database_feature =  global_feature + alpha * local_feature
            get_similarity = similarity_score(database_feature, query_features)
            all_similarity_score.append(get_similarity)


        
        def get_top_k_indices(input_list, k):
            return sorted(range(len(input_list)), key=lambda i: input_list[i], reverse=True)[:k]



        


        top_k_indices = get_top_k_indices(all_similarity_score, k)

        top_k_results_path = [database_file_path[indices]  for indices in top_k_indices]
        top_k_results = [database_filenames[indices]  for indices in top_k_indices]

        one_average_precision = average_precision(top_k_results, relevant_images)



        all_average_precision.append(one_average_precision)



        # query image
        print("\n")
        print("Random Image Number -> ", i_random+1)

        img = Image.open(query_image_path)
        plt.imshow(img)
        plt.title("Query Images")
        plt.show()

        print("\n")
        print("\n")

        print("Average precision for this image", one_average_precision)


                # Open and display each image
        for i in range(min(top_no_image_print, k)):
            path = top_k_results_path[i]
            img = Image.open(path)
            plt.imshow(img)
            plt.title(f'Top {i+1} result')  # Add a title to the image
            plt.show()

        print("\n")
        print("\n")

        print("_________________________________________________________________________________________________________________________________________")
        print("\n")
        print("\n")




    return all_average_precision


In [None]:
""
ap = get_random_image(no_random_images,database_filenames, database_images, database_file_path )



In [63]:
map_ = mean_average_precision(ap)

print(map_)

0.5895833333333333
