pip install torchtext

In [1]:
import torch
from torchtext.vocab import GloVe



In [14]:
class FilterPredictions:
    """Pick the candidate caption whose GloVe embedding is closest to a reference.

    Captions are embedded as the mean of their per-word GloVe vectors and
    compared with cosine similarity.
    """

    def __init__(self):
        # Pretrained '6B' GloVe vectors with 100 dimensions per word.
        self.glove = GloVe(name='6B', dim=100)

    def get_sentence_embedding(self, phrase: str) -> torch.Tensor:
        """Return the mean GloVe vector of the known words in *phrase*.

        The phrase is split on whitespace and words that are not in
        ``self.glove.stoi`` are skipped. When no word is known (which
        includes the empty phrase) a zero vector of length
        ``self.glove.dim`` is returned instead.
        """
        vectors = [
            self.glove[word]
            for word in phrase.split()
            if word in self.glove.stoi
        ]
        if not vectors:
            return torch.zeros(self.glove.dim)
        return torch.mean(torch.stack(vectors), dim=0)

    def closest_prediction(
        self,
        predictions: list[str],
        truth: str,
    ) -> tuple[dict[str, torch.Tensor], dict]:
        """Score every prediction against *truth* and report the best one.

        Args:
            predictions: Candidate captions; must not be empty.
            truth: Reference caption the candidates are compared to.

        Returns:
            A pair ``(similarities, best)``. ``similarities`` maps each
            distinct caption to its cosine similarity with *truth*.
            ``best`` has the keys ``'index'`` (position in *predictions* of
            the first caption with the highest similarity), ``'caption'``
            and ``'similarity'``.

        Raises:
            ValueError: If *predictions* is empty.
        """
        if not predictions:
            raise ValueError("predictions must contain at least one caption")

        # The reference embedding does not depend on the candidate, so it is
        # computed once instead of once per prediction.
        truth_embedding = self.get_sentence_embedding(truth)

        # Scores are kept in a list parallel to `predictions`. Deriving the
        # winning index from the dict below would be wrong whenever a caption
        # occurs more than once, because duplicate keys collapse.
        scores = [
            torch.nn.functional.cosine_similarity(
                self.get_sentence_embedding(prediction), truth_embedding, dim=0
            )
            for prediction in predictions
        ]
        similarities = dict(zip(predictions, scores))

        # max() returns the first maximal element, so ties go to the
        # earliest prediction.
        best_index = max(range(len(scores)), key=lambda i: scores[i].item())

        return similarities, {
            'index': best_index,
            'caption': predictions[best_index],
            'similarity': scores[best_index],
        }

In [15]:
filter = FilterPredictions()

In [16]:
truth = "red helicopter"
predictions = ["red jet", "gray helicopter", "red helicopter"]

In [17]:
filter.closest_prediction(predictions=predictions, truth=truth)

({'red jet': tensor(0.8939),
  'gray helicopter': tensor(0.8951),
  'red helicopter': tensor(1.0000)},
 {'index': 2, 'caption': 'red helicopter', 'similarity': tensor(1.0000)})