<a href="https://colab.research.google.com/github/Alwin-Lin/embeddingsNearestNeighborSearch/blob/main/Recommendation_using_embeddings_and_nearest_neighbor_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Recommendation using embeddings and nearest neighbor search


In [None]:
import os
import pickle
from typing import List, Optional

import google.generativeai as genai
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# Prefer a key from the environment so real credentials never have to be
# written into (and committed with) the notebook; fall back to the original
# placeholder so behavior is unchanged when neither variable is set.
genai.configure(
    api_key=os.environ.get("GEMINI_API_KEY")
    or os.environ.get("GOOGLE_API_KEY")
    or 'YOUR_GEMINI_API_KEY'
)

In [None]:
class EmbeddingUtils:
    """Stateless helpers for fetching embeddings and ranking them by distance."""

    @staticmethod
    def get_embedding(
        text: str,
        model: str = "embedding-001",
        task_type: str = "retrieval_document",
    ) -> Optional[List[float]]:
        """Get an embedding for ``text`` from Gemini's embedding model.

        Args:
            text: The text to embed.
            model: Gemini embedding model name.
            task_type: Task-type hint forwarded to ``genai.embed_content``.

        Returns:
            The ``'embedding'`` entry of the API response, or ``None`` if the
            call (or the lookup of that entry) raised. The error is printed,
            so callers must check for ``None``.
        """
        try:
            embedding = genai.embed_content(
                model=model,
                content=text,
                task_type=task_type,
            )
            return embedding['embedding']
        except Exception as e:
            # Deliberate best-effort: report and signal failure with None.
            print(f"Error getting embedding: {e}")
            return None

    @staticmethod
    def distances_from_embeddings(
        query_embedding: List[float],
        embeddings: List[List[float]],
        distance_metric: str = "cosine",
    ) -> List[float]:
        """Compute the distance from ``query_embedding`` to each of ``embeddings``.

        Args:
            query_embedding: The query vector.
            embeddings: Candidate vectors, each the same length as the query.
            distance_metric: One of ``"cosine"`` (1 - cosine similarity),
                ``"L1"`` (Manhattan), ``"L2"`` (Euclidean) or ``"Linf"``
                (Chebyshev / max-abs difference).

        Returns:
            One distance per candidate, in the same order as ``embeddings``.
            An empty candidate list yields ``[]``.

        Raises:
            ValueError: if ``distance_metric`` is not supported.
        """
        if distance_metric not in ("cosine", "L1", "L2", "Linf"):
            raise ValueError(f"Unsupported distance metric: {distance_metric}")
        # np.array([]) is 1-D and would make the 2-D math below fail.
        if len(embeddings) == 0:
            return []

        query = np.asarray(query_embedding, dtype=float).reshape(1, -1)
        matrix = np.asarray(embeddings, dtype=float)

        if distance_metric == "cosine":
            distances = 1 - cosine_similarity(query, matrix)[0]
        elif distance_metric == "L1":
            distances = np.abs(matrix - query).sum(axis=1)
        elif distance_metric == "L2":
            distances = np.linalg.norm(matrix - query, axis=1)
        else:  # "Linf"
            distances = np.abs(matrix - query).max(axis=1)

        return distances.tolist()

    @staticmethod
    def indices_of_nearest_neighbors_from_distances(distances: List[float]) -> List[int]:
        """Return every index of ``distances``, ordered nearest (smallest) first.

        ``sorted`` is stable, so tied distances keep their original order.
        """
        return sorted(range(len(distances)), key=lambda k: distances[k])

In [None]:
class RecommendationSystem:
    """Rank strings by embedding similarity, caching embeddings in a pickle file.

    The cache maps ``(text, model)`` tuples to embedding vectors and is
    persisted at ``embedding_cache_path``.
    """

    def __init__(self, embedding_cache_path: str = "recommendations_embeddings_cache.pkl"):
        self.embedding_cache_path = embedding_cache_path
        self.embedding_cache = self._load_cache()
        self.utils = EmbeddingUtils()

    def _load_cache(self):
        """Load the embedding cache from disk, or start with an empty one.

        A missing file yields ``{}``. A truncated/unreadable pickle (e.g. left
        by an interrupted write) also yields ``{}`` with a printed warning, so
        a damaged cache is rebuilt instead of breaking construction.
        """
        # SECURITY: pickle.load can execute arbitrary code. Only point this at
        # cache files this program wrote itself, never at untrusted files.
        try:
            with open(self.embedding_cache_path, 'rb') as f:
                return pickle.load(f)
        except FileNotFoundError:
            return {}
        except (EOFError, pickle.UnpicklingError) as e:
            print(f"Ignoring unreadable embedding cache {self.embedding_cache_path!r}: {e}")
            return {}

    def _save_cache(self):
        """Persist the embedding cache to disk.

        Writes a temporary sibling file and then atomically swaps it into
        place with ``os.replace``, so an interruption mid-write cannot leave
        a truncated cache file behind.
        """
        tmp_path = f"{self.embedding_cache_path}.tmp"
        with open(tmp_path, 'wb') as f:
            pickle.dump(self.embedding_cache, f)
        os.replace(tmp_path, self.embedding_cache_path)

    def get_embedding_with_cache(
        self, text: str, model: str = "embedding-001", *, save: bool = True
    ) -> Optional[List[float]]:
        """Return the embedding for ``text``, fetching and caching it on a miss.

        Args:
            text: Text to embed.
            model: Embedding model name; part of the cache key.
            save: Persist the cache right after adding a new entry. Batch
                callers pass ``False`` and call ``_save_cache`` once.

        Returns:
            The embedding, or ``None`` if it could not be fetched. Failures
            are not cached, so a later call retries the API (this also heals
            ``None`` entries written by older versions of this class).
        """
        cache_key = (text, model)
        cached = self.embedding_cache.get(cache_key)
        if cached is not None:
            return cached

        embedding = self.utils.get_embedding(text, model)
        if embedding is not None:
            self.embedding_cache[cache_key] = embedding
            if save:
                self._save_cache()
        return embedding

    def get_recommendations(
        self,
        strings: List[str],
        index_of_source_string: int,
        k_nearest_neighbors: int = 1,
        model: str = "embedding-001"
    ) -> tuple[List[int], List[float]]:
        """Rank every string in ``strings`` by cosine distance to the source string.

        Args:
            strings: Candidate strings; the source string is one of them.
            index_of_source_string: Index in ``strings`` of the query string.
            k_nearest_neighbors: Currently unused. The full ranking is
                returned and callers truncate it; kept for interface
                compatibility.
            model: Embedding model name.

        Returns:
            ``(indices, distances)``. ``indices`` lists every position of
            ``strings`` from nearest to farthest (ties keep input order).
            ``distances[i]`` is the cosine distance of ``strings[i]`` from
            the source string.

        Raises:
            ValueError: if any embedding could not be fetched.
        """
        embeddings = []
        newly_cached = False
        try:
            for string in strings:
                was_cached = self.embedding_cache.get((string, model)) is not None
                embedding = self.get_embedding_with_cache(string, model=model, save=False)
                newly_cached = newly_cached or (not was_cached and embedding is not None)
                embeddings.append(embedding)
        finally:
            # One disk write per batch instead of one per new string; the
            # finally keeps already-fetched embeddings if a later call fails.
            if newly_cached:
                self._save_cache()

        failed = [i for i, emb in enumerate(embeddings) if emb is None]
        if failed:
            raise ValueError(f"Could not get embeddings for strings at indices {failed}")

        query_embedding = embeddings[index_of_source_string]
        distances = self.utils.distances_from_embeddings(
            query_embedding,
            embeddings,
            distance_metric="cosine"
        )
        indices = self.utils.indices_of_nearest_neighbors_from_distances(distances)
        return indices, distances

def print_recommendations(
    recommender: RecommendationSystem,
    strings: List[str],
    index_of_source_string: int,
    k_nearest_neighbors: int = 1
) -> None:
    """Print the source string, then up to ``k_nearest_neighbors`` neighbors.

    Neighbors are shown nearest first, each with its distance. Any string
    whose text equals the source string (including the source itself) is
    not shown.
    """
    ranked, dists = recommender.get_recommendations(
        strings=strings,
        index_of_source_string=index_of_source_string,
        k_nearest_neighbors=k_nearest_neighbors
    )

    source = strings[index_of_source_string]
    print(f"Source string: {source}")

    # Lazily walk the ranking, dropping exact duplicates of the source text.
    distinct = (idx for idx in ranked if not source == strings[idx])
    for rank, idx in enumerate(distinct, start=1):
        if rank > k_nearest_neighbors:
            break
        print(
            f"""
        --- Recommendation #{rank} (nearest neighbor {rank} of {k_nearest_neighbors}) ---
        String: {strings[idx]}
        Distance: {dists[idx]:0.3f}"""
        )

In [None]:
# Example: recommend the AG News articles most similar to the first one.
if __name__ == "__main__":
    news = pd.read_csv("data/AG_news_samples.csv")
    recommender = RecommendationSystem()
    descriptions = news["description"].tolist()

    print_recommendations(
        recommender=recommender,
        strings=descriptions,
        index_of_source_string=0,  # the first article is the query
        k_nearest_neighbors=5,
    )