# Nearest words via embeddings
Here is the ChatGPT prompt used to generate this code:

> Write a python program using pytorch to create embeddings for a list of words. Then given the input of one of the words find the 5 closest words to it.

Question: Are the embeddings normalized, i.e., do they have length 1?
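One way to answer this is to measure the L2 norm of each row. `nn.Embedding` initializes its weights from a standard normal distribution N(0, 1) by default, so the rows are not unit length; for 50 dimensions their norms cluster around √50 ≈ 7. `F.normalize` rescales each row to length 1, which is what `find_closest` does before taking dot products. A minimal sketch, assuming the same sizes as the notebook (10 words, 50 dimensions) and a fixed seed chosen only for reproducibility:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # fixed seed so the check is reproducible

# Same shape as the notebook's layer: 10 words, 50 dimensions
embedding_layer = nn.Embedding(num_embeddings=10, embedding_dim=50)

with torch.no_grad():
    raw = embedding_layer.weight          # shape [10, 50], drawn from N(0, 1)
    raw_norms = raw.norm(p=2, dim=1)      # L2 length of each word vector
    unit = F.normalize(raw, p=2, dim=1)   # rescale each row to length 1
    unit_norms = unit.norm(p=2, dim=1)

print("raw norms:       ", raw_norms)
print("normalized norms:", unit_norms)
print("raw rows unit length?       ", torch.allclose(raw_norms, torch.ones(10)))
print("normalized rows unit length?", torch.allclose(unit_norms, torch.ones(10), atol=1e-6))
```

So the raw embeddings are not normalized; cosine similarity only works out because the vectors are normalized explicitly at comparison time.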

## Exercise in Class
Add the ability to embed 5 phrases of up to 30 words each. Then, given a query, select the phrase most likely to answer it.
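One possible sketch of the exercise, assuming mean pooling (averaging the word vectors of a phrase into one vector) and cosine similarity between the query and each phrase. The five phrases, the whitespace tokenizer, and the helper names `tokenize`, `embed_text`, and `best_phrase` are all hypothetical choices for illustration, not part of the original notebook:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)  # reproducible random embeddings

MAX_WORDS = 30  # the exercise's limit on phrase length

# Hypothetical example phrases; replace with your own five
phrases = [
    "bananas are a good source of potassium",
    "the robot vacuum cleans the floor every morning",
    "cherries and grapes grow in clusters",
    "mango trees need a warm climate to fruit",
    "apples can be stored in a cool dry place",
]

def tokenize(text):
    """Lowercase, split on whitespace, keep at most MAX_WORDS tokens."""
    return text.lower().split()[:MAX_WORDS]

# Vocabulary built from the phrases themselves
vocab = {}
for phrase in phrases:
    for tok in tokenize(phrase):
        vocab.setdefault(tok, len(vocab))

# Random, untrained embeddings, as in the notebook
embedding_layer = nn.Embedding(num_embeddings=len(vocab), embedding_dim=50)

def embed_text(text):
    """Mean-pool the embeddings of known words, then scale to unit length."""
    ids = [vocab[tok] for tok in tokenize(text) if tok in vocab]
    if not ids:  # no known words: zero vector, similarity 0 to every phrase
        return torch.zeros(embedding_layer.embedding_dim)
    with torch.no_grad():
        vec = embedding_layer(torch.tensor(ids)).mean(dim=0)
    return F.normalize(vec, p=2, dim=0)

# One unit-length row per phrase: shape [5, 50]
phrase_matrix = torch.stack([embed_text(p) for p in phrases])

def best_phrase(query):
    """Return (index, phrase, cosine similarity) of the closest phrase."""
    sims = phrase_matrix @ embed_text(query)
    idx = int(torch.argmax(sims))
    return idx, phrases[idx], sims[idx].item()

idx, phrase, score = best_phrase("how should apples be stored in a cool place")
print(f"best match: {phrase!r} (cosine similarity: {score:.4f})")
```

Because the embeddings are random, this only rewards words the query and a phrase share exactly (`stored` and `store` count as unrelated). Matching paraphrases would need trained word vectors or a sentence-embedding model.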

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

```python
# Define a list of words (the vocabulary)
words = ["apple", "banana", "orange", "pear", "peach",
         "mango", "grape", "cherry", "berry", "robot"]

# Create mappings from word to index and from index to word
word2idx = {word: idx for idx, word in enumerate(words)}
idx2word = {idx: word for idx, word in enumerate(words)}

# Set the embedding dimension and create the embedding layer.
# Its weights are randomly initialized from N(0, 1) and are never trained here.
embedding_dim = 50
embedding_layer = nn.Embedding(num_embeddings=len(words), embedding_dim=embedding_dim)

# Look up the embeddings for all words in the vocabulary
# (a matrix of shape [vocab_size, embedding_dim]).
# no_grad() skips gradient tracking, since nothing is trained.
with torch.no_grad():
    embeddings = embedding_layer(torch.arange(len(words)))

# Print the embeddings for the first two words
print("Embeddings for the first two words:")
print(embeddings[:2])
# Print the embeddings for the last two words
print("Embeddings for the last two words:")
print(embeddings[-2:])
# Print the embedding for the word "apple"
print("Embedding for the word 'apple':")
print(embeddings[word2idx["apple"]])
```


```
Embeddings for the first two words:
tensor([[ 0.3495,  0.0289,  0.2430, -1.0298,  0.9661, -0.2655, -0.2611,  0.5792,
          2.0585,  0.4058, -1.1492, -0.5025,  0.3413,  0.2476, -0.3349, -0.8631,
          0.4595,  0.8828,  1.0464,  0.8244, -0.5514,  1.0425,  1.5809, -0.0692,
         -1.7990, -0.0602,  0.3207,  1.4453, -0.4037, -1.3846,  0.7818, -0.5465,
         -0.7833,  0.0557,  0.5393,  0.1059, -0.9160, -1.7850, -0.1819,  1.7937,
         -0.1177, -0.3234,  2.3065,  0.1515,  1.2161, -0.5575, -0.0381, -0.3167,
         -0.0737, -0.3233],
        [-1.5515,  1.2578,  2.1755,  0.5963,  1.1017, -0.4185, -0.1118,  0.8138,
         -0.5492, -0.2341,  1.1922, -1.3007,  1.1381,  1.4680,  0.3003, -0.3165,
         -0.3893, -0.0549, -1.0157,  0.0590, -0.8525,  0.8212,  0.2667, -0.1814,
         -0.2477, -0.5772, -1.3342, -0.8670,  1.0715,  0.4388,  0.4143, -0.3818,
          0.8707,  0.9559,  1.0466, -0.8718, -0.9823,  1.0048, -1.4688, -0.1079,
          0.7671,  0.4526, -1.0224, -0.4266,
```
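`nn.Embedding` is a lookup table: indexing it with `torch.arange(len(words))` returns every row of its weight matrix in order, so the `embeddings` matrix above holds the same values as `embedding_layer.weight`. A minimal check, with a fresh seeded 10 × 50 layer standing in for the notebook's:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, embedding_dim = 10, 50   # same sizes as the notebook
embedding_layer = nn.Embedding(vocab_size, embedding_dim)

with torch.no_grad():
    all_rows = embedding_layer(torch.arange(vocab_size))   # look up every index
    some_rows = embedding_layer(torch.tensor([3, 0, 3]))   # arbitrary indices, repeats allowed

print(all_rows.shape)                                       # torch.Size([10, 50])
print(torch.equal(all_rows, embedding_layer.weight))        # True: full lookup == weight matrix
print(torch.equal(some_rows[0], embedding_layer.weight[3])) # True: looking up index i returns weight[i]
```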

```python
def find_closest(word, top_k=5):
    """Finds the top_k closest words to the input word using cosine similarity."""
    if word not in word2idx:
        print(f"Word '{word}' not found in vocabulary.")
        return []

    # Get the embedding for the input word
    word_index = word2idx[word]
    word_embedding = embeddings[word_index]

    # Normalize all embeddings to unit length
    normalized_embeddings = F.normalize(embeddings, p=2, dim=1)
    normalized_word_embedding = F.normalize(word_embedding, p=2, dim=0)

    # Cosine similarity = dot product between unit-length vectors
    similarities = torch.matmul(normalized_embeddings, normalized_word_embedding)

    # Get the top (top_k + 1) most similar words, since the word itself is among
    # them; cap k at the vocabulary size so torch.topk does not raise an error
    k = min(top_k + 1, len(words))
    top_values, top_indices = torch.topk(similarities, k)

    results = []
    for value, idx in zip(top_values, top_indices):
        # Skip the word itself
        if idx.item() == word_index:
            continue
        results.append((idx2word[idx.item()], value.item()))
        if len(results) == top_k:
            break
    return results
```
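Normalizing both vectors and then taking a dot product is exactly cosine similarity, so the `similarities` vector computed in `find_closest` should agree with PyTorch's built-in `F.cosine_similarity`. A minimal check on a random 10 × 50 stand-in matrix (the seed and the choice of row 1 as the query are arbitrary):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
embeddings = torch.randn(10, 50)   # stand-in for the notebook's embedding matrix
query = embeddings[1]              # e.g. the row for "banana"

# Approach used in find_closest: normalize, then dot product
manual = F.normalize(embeddings, p=2, dim=1) @ F.normalize(query, p=2, dim=0)

# Built-in cosine similarity; the [1, 50] query broadcasts against every row
builtin = F.cosine_similarity(embeddings, query.unsqueeze(0), dim=1)

print(torch.allclose(manual, builtin, atol=1e-6))  # True
print(manual[1].item())                            # about 1.0: a vector is most similar to itself
```

This is also why `find_closest` asks `torch.topk` for `top_k + 1` results: the query word always ranks first with a similarity of about 1. Setting that entry to `-inf` before calling `torch.topk` would be an equivalent way to exclude it.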



```python
# Example usage
input_word = "banana"
closest_words = find_closest(input_word)

if closest_words:
    print(f"Top {len(closest_words)} words similar to '{input_word}':")
    for word, similarity in closest_words:
        print(f"{word} (cosine similarity: {similarity:.4f})")
```

```
Top 5 words similar to 'banana':
berry (cosine similarity: 0.1695)
cherry (cosine similarity: 0.1219)
orange (cosine similarity: 0.0211)
robot (cosine similarity: 0.0143)
mango (cosine similarity: -0.0414)
```
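All five similarities above are close to zero, and "robot" ranks alongside the fruits. That is expected here: the embeddings are random and untrained, and two independent random vectors in 50 dimensions are nearly orthogonal, with cosine similarity centered on 0 and a standard deviation of 1/√50 ≈ 0.14. A quick simulation of that spread (the seed and the number of sampled pairs are arbitrary choices):

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
dim, n_pairs = 50, 20_000   # notebook's embedding size; arbitrary sample count

# Pairs of independent N(0, 1) vectors, like untrained nn.Embedding rows
a = torch.randn(n_pairs, dim)
b = torch.randn(n_pairs, dim)
cos = F.cosine_similarity(a, b, dim=1)

print(f"mean cosine similarity: {cos.mean().item():+.4f}")   # close to 0
print(f"std of cosine similarity: {cos.std().item():.4f}")   # close to 1/sqrt(dim)
print(f"1/sqrt(dim):              {1 / math.sqrt(dim):.4f}")  # 0.1414
```

Every similarity in the banana example (0.1695 down to -0.0414) lies within about 1.2 standard deviations of zero, well within the range produced by chance, so the ranking says nothing about fruit. Meaningful neighbors would require trained or pretrained embeddings.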
