In [86]:
import random
import torch
import torch.nn.functional as F

In [131]:
def GetCosineSimilarityDistance(image_index, img_vector_dict, similarity_value):
    """
    Find the images whose vectors are at least `similarity_value` cosine-similar
    to a reference image.

    Args:
        image_index: key of the reference image in `img_vector_dict`.
        img_vector_dict (dictionary): maps an image key to its tensor. All tensors
            must have the same shape so they can be stacked.
        similarity_value (float): threshold in the closed range [-1, 1]. A value of
            -1 returns every key; a value of 1 returns only images whose vector
            points in the same direction as the reference.
            NOTE: because of floating-point rounding the similarity of a vector
            with itself may come out marginally below 1, so a threshold of
            exactly 1 may exclude even the reference image.

    Returns:
        list: keys of `img_vector_dict` whose cosine similarity to the reference is
        larger than or equal to `similarity_value`, ordered from most to least
        similar. The reference key itself is part of the list whenever its
        self-similarity meets the threshold.

    Raises:
        ValueError: if `similarity_value` is outside [-1, 1] (including NaN).
        KeyError: if `image_index` is not a key of `img_vector_dict`.
    """

    # Reject thresholds outside [-1, 1]. Written as a negated chained comparison
    # so that NaN (which fails every comparison) is rejected as well.
    if not (-1 <= similarity_value <= 1):
        raise ValueError ('the similarity_value must range between 1 to -1, but got %f instead' %similarity_value)

    # Add a leading batch dimension so the reference broadcasts against the stack.
    reference_vector = img_vector_dict[image_index].unsqueeze(0)

    # Freeze the key order once; position i in the stack belongs to list_of_keys[i].
    list_of_keys = list(img_vector_dict.keys())
    vector_stack = torch.stack([img_vector_dict[key] for key in list_of_keys])

    # Cosine similarity between the reference vector and every stored vector.
    cos_similarity = F.cosine_similarity(vector_stack, reference_vector)

    # Order from most to least similar; `order` holds positions into list_of_keys.
    sorted_similarity, order = cos_similarity.sort(descending=True)

    # Keep only positions whose similarity meets the threshold. A boolean mask
    # selects exactly the qualifying entries while preserving the sorted order.
    selected_positions = order[sorted_similarity >= similarity_value].tolist()

    # Translate stack positions back into the dictionary keys of the images.
    return [list_of_keys[position] for position in selected_positions]