# Maximal Marginal Relevance

```python
%%capture --no-stderr
!pip install transformers
!pip install torch
```

```python
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
```

Imagine you want to find the most relevant news articles about `London` from a dataset. Here’s an example dataset containing news article titles:

```python
article_titles = [
    # Culture
    "The Revival of Ancient Traditions in Modern Society",
    "Exploring the Intersection of Art and Technology",
    "A Deep Dive into Indigenous Music Around the World",
    "How Street Art is Shaping Urban Culture",
    "Cultural Festivals You Can't Miss This Year",
    "The Impact of Globalization on Local Cultures",
    "Art Exhibitions That Will Transform Your Perspective",
    "The Evolution of Theatre in London",
    "The Influence of Eastern Philosophy in Western Culture",
    "Exploring London's Cultural Scene",

    # Weather
    "Understanding the Science Behind Extreme Weather Events",
    "How Climate Change is Affecting Global Weather Patterns",
    "Preparing for Hurricane Season: What You Need to Know",
    "The Future of Weather Forecasting: Innovations and Challenges",
    "Heatwaves: Causes, Effects, and Mitigation Strategies",
    "The Impact of El Niño and La Niña on Global Weather",
    "Winter Storms: Preparing for the Unexpected",
    "How Urbanization Affects Local Weather Patterns",
    "Weather Patterns Affecting London This Summer",
    "How London's Weather Has Changed Over the Decades",

    # World News
    "Global Leaders Convene to Discuss Climate Action",
    "The Economic Impacts of the Latest Trade Agreements",
    "Technological Advancements in Developing Nations",
    "Elections Around the World: Key Outcomes and Implications",
    "The Role of Social Media in Modern Revolutions",
    "Global Health Initiatives: Progress and Challenges",
    "Diplomatic Tensions and Their Global Ramifications",
    "The Growing Influence of Social Media on Photography",
    "London's Role in Global Climate Talks",
    "Brexit and Its Impact on London's Economy",

    # Programming Languages
    "Top 10 Programming Languages to Learn in 2024",
    "How Python Became the Go-To Language for Data Science",
    "The Growing Popularity of Rust in Systems Programming",
    "Comparing Functional and Object-Oriented Programming Paradigms",
    "The Impact of Open Source on Programming Language Development",
    "The Role of Swift in Apple's Ecosystem",
    "Emerging Programming Languages to Keep an Eye On",
    "Python Tips for Mastering Data Science",
    "London's Tech Scene and Programming Trends",
    "How London's Startups are Using AI",

    # Photography News
    "The Intersection of Photography and Virtual Reality",
    "How Drones are Revolutionizing Aerial Photography",
    "The Best New Cameras and Lenses of the Year",
    "Exploring the World of Underwater Photography",
    "The Art of Portrait Photography: Techniques and Tips",
    "The Role of Post-Processing in Modern Photography",
    "The Intersection of Photography and Virtual Reality",
    "How to Build a Professional Photography Portfolio",
    "Exploring London's Best Photography Spots",
    "The Best Photography Exhibitions in London",

    # Things to Do in London
    "Historic Landmarks You Can't Miss in London",
    "Family-Friendly Activities in London",
    "Family-Friendly Events in London",
    "The Best Parks and Green Spaces in London",
    "Unique Shopping Experiences in London",
    "Day Trips from London: Exploring the Countryside",
    "Cultural Festivals in London This Year",
    "Top Photography Spots in London",
    "London's Tech Scene and Programming Trends",
    "Best Photo Spots in London"
]
```

One way to approach this problem is to use a [similarity measure](https://en.wikipedia.org/wiki/Similarity_measure) to rank documents by their relevance to a query. A common choice is [cosine similarity](https://en.wikipedia.org/wiki/Cosine_similarity) between vector [embeddings](https://en.wikipedia.org/wiki/Word_embedding) of the documents and the query. The algorithm follows these steps:

1. Create embedding representations of the documents and the query.
2. Compute the cosine similarity between the query's representation and each document's representation.
3. Select the `N` documents most similar to the query.
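The three steps can be sketched on hypothetical toy vectors standing in for real embeddings. The 3-dimensional vectors and the `top_n_by_cosine` helper below are illustrative assumptions, not model output; step 1 is assumed to have already happened.

```python
import numpy as np

def top_n_by_cosine(query_vec, doc_vecs, n):
    """Return indices of the n documents most cosine-similar to the query (best first)."""
    # Step 1 (assumed done): query_vec and doc_vecs are already embeddings.
    query_vec = np.asarray(query_vec, dtype=float)
    doc_vecs = np.asarray(doc_vecs, dtype=float)
    # Step 2: cosine similarity between the query and every document.
    sims = doc_vecs @ query_vec / (
        np.linalg.norm(doc_vecs, axis=1) * np.linalg.norm(query_vec))
    # Step 3: indices of the n highest similarities.
    return [int(i) for i in np.argsort(-sims)[:n]]

# Hypothetical toy embeddings: docs 0 and 1 point roughly the same way as the query.
docs = [[1.0, 0.0, 0.0],
        [0.9, 0.1, 0.0],
        [0.0, 1.0, 0.0]]
query = [1.0, 0.0, 0.0]
print(top_n_by_cosine(query, docs, 2))  # [0, 1]
```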

We can use the [transformers](https://huggingface.co/docs/transformers/en/index) library to leverage a pretrained language model (such as the classic `bert-base-uncased`) to generate text embeddings.

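The function below produces one embedding per input text by averaging the model's last-layer hidden states over all token positions. Note that this plain mean also includes the padding positions added when texts of different lengths are batched together.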

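```python
def get_embeddings(text_list, model_name="bert-base-uncased"):
    # Load the tokenizer and model (downloaded on first use, then cached).
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    # Tokenize the whole batch; shorter texts are padded to the longest one.
    inputs = tokenizer(text_list, return_tensors="pt", padding=True, truncation=True)

    # Inference only: no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)

    # Mean over the token dimension: (batch, seq_len, hidden) -> (batch, hidden).
    # This plain mean includes [CLS], [SEP] and padding positions.
    embeddings = outputs.last_hidden_state.mean(dim=1)
    embeddings = embeddings.numpy()

    return embeddings
```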
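Because `padding=True` pads shorter titles, `last_hidden_state.mean(dim=1)` also averages in the hidden states at padding positions. A common alternative is to weight the average by the attention mask. The sketch below shows the idea with NumPy arrays standing in for the model's tensors; `masked_mean_pool` is a hypothetical helper, and in PyTorch the same arithmetic could use `inputs["attention_mask"].unsqueeze(-1)` with `.sum(dim=1)`. Switching to it would change the similarity scores printed later in this notebook.

```python
import numpy as np

def masked_mean_pool(hidden_states, attention_mask):
    """Average token vectors, ignoring positions where attention_mask == 0.

    hidden_states: array of shape (batch, seq_len, dim)
    attention_mask: array of shape (batch, seq_len); 1 = real token, 0 = padding
    """
    mask = attention_mask[:, :, None].astype(hidden_states.dtype)  # (batch, seq_len, 1)
    summed = (hidden_states * mask).sum(axis=1)                     # sum over real tokens only
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                  # real tokens per row, avoid /0
    return summed / counts

# Toy batch: the second sequence has one padding position at the end.
hidden = np.array([[[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]],
                   [[2.0, 0.0], [4.0, 0.0], [9.0, 9.0]]])
mask = np.array([[1, 1, 1],
                 [1, 1, 0]])
print(hidden.mean(axis=1))             # second row becomes [5, 3]: the padding vector leaks in
print(masked_mean_pool(hidden, mask))  # second row becomes [3, 0]: padding is skipped
```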

Cosine similarity is computed as $\cos(A, B) = \frac{A \cdot B}{\|A\| \, \|B\|}$, where $A \cdot B$ is the dot product and $\|A\|$ is the Euclidean norm:

```python
def cosine_similarity(vector1, vector2):
    dot_product = np.dot(vector1, vector2)
    norm_vector1 = np.linalg.norm(vector1)
    norm_vector2 = np.linalg.norm(vector2)
    return dot_product / (norm_vector1 * norm_vector2)
```
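A quick sanity check on hand-picked vectors (illustrative values, not from the dataset): vectors pointing the same way score 1 regardless of length, orthogonal vectors score 0, and vectors 45° apart score $1/\sqrt{2} \approx 0.7071$. The function is repeated so the snippet runs on its own.

```python
import numpy as np

def cosine_similarity(vector1, vector2):
    dot_product = np.dot(vector1, vector2)
    return dot_product / (np.linalg.norm(vector1) * np.linalg.norm(vector2))

print(cosine_similarity([1, 0], [2, 0]))  # same direction, different length: 1.0
print(cosine_similarity([1, 0], [0, 3]))  # orthogonal: 0.0
print(cosine_similarity([1, 0], [1, 1]))  # 45 degrees apart: about 0.7071
```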

To compute cosine similarities between the query and each document, we combine the query and the documents into a single list and compute the similarity between every pair of embeddings. This also yields the document-to-document similarities, which we will need later for MMR. In the resulting matrix, row and column 0 correspond to the query, and document `i` sits at index `i + 1`.

```python
def compute_similarities(documents, query):
    # Index 0 is the query; document i ends up at index i + 1.
    query_and_documents = [query] + documents
    embeddings = get_embeddings(query_and_documents)
    num_embeddings = embeddings.shape[0]
    similarities = np.zeros((num_embeddings, num_embeddings))

    # Fill the full pairwise cosine-similarity matrix.
    for i in range(num_embeddings):
        for j in range(num_embeddings):
            similarities[i, j] = cosine_similarity(embeddings[i], embeddings[j])

    return similarities

article_similarities = compute_similarities(article_titles, query="London")
```
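The double loop calls `cosine_similarity` $n^2$ times. An equivalent sketch, assuming `embeddings` is a 2-D NumPy array with one row per text (as `get_embeddings` returns), normalizes each row once and takes a single matrix product. `cosine_similarity_matrix` is a hypothetical name.

```python
import numpy as np

def cosine_similarity_matrix(embeddings):
    """Pairwise cosine similarities between the rows of a 2-D array."""
    embeddings = np.asarray(embeddings, dtype=float)
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)  # shape (n, 1)
    unit = embeddings / norms                                   # every row scaled to length 1
    return unit @ unit.T                                        # (n, n) matrix of dot products

# Compare against the element-wise double loop on random toy vectors.
rng = np.random.default_rng(0)
toy = rng.normal(size=(5, 8))
loop = np.array([[np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) for b in toy]
                 for a in toy])
print(np.allclose(cosine_similarity_matrix(toy), loop))  # True
```

Inside `compute_similarities`, the loop could then become `similarities = cosine_similarity_matrix(embeddings)`, with results equal up to floating-point rounding.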

Next, rank the documents based on their similarity to the query:

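```python
def similarity_relevance(similarities):
    # Sort document indices by similarity to the query (row 0),
    # skipping column 0, which is the query's similarity to itself.
    return sorted(range(len(similarities) - 1),
                  key=lambda i: similarities[0, i+1],
                  reverse=True)
```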
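The same ranking can also be sketched with NumPy's `argsort` on the query row. `similarity_relevance_np` and the 4×4 matrix below are made-up illustrations (row/column 0 is the query, as in `compute_similarities`).

```python
import numpy as np

def similarity_relevance_np(similarities):
    # Slicing off column 0 (the query itself) makes argsort return document indices directly.
    # kind="stable" keeps tied documents in their original order, as sorted() does.
    return [int(i) for i in np.argsort(-similarities[0, 1:], kind="stable")]

toy = np.array([[1.0, 0.2, 0.9, 0.5],
                [0.2, 1.0, 0.1, 0.3],
                [0.9, 0.1, 1.0, 0.4],
                [0.5, 0.3, 0.4, 1.0]])
print(similarity_relevance_np(toy))  # [1, 2, 0]
```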

```python
def print_selected(ranking, similarities, entities):
    # Print each selected item next to its similarity to the query.
    for i in ranking:
        print(f"{similarities[0, i+1]}\t{entities[i]}")
```

Here are the 7 most relevant news articles:

```python
article_similarity_order = similarity_relevance(article_similarities)
print_selected(article_similarity_order[:7], article_similarities, article_titles)
```

```
0.7103888392448425	Family-Friendly Activities in London
0.7099422812461853	The Best Photography Exhibitions in London
0.7076156735420227	Unique Shopping Experiences in London
0.7006320357322693	Family-Friendly Events in London
0.6981644034385681	Best Photo Spots in London
0.6938719153404236	Top Photography Spots in London
0.6879022121429443	Exploring London's Cultural Scene
```


These articles are relevant to London, but some are near-duplicates:
- "Family-Friendly Activities in London" and "Family-Friendly Events in London"
- "Best Photo Spots in London" and "Top Photography Spots in London"

In many cases, it is essential to retrieve documents that are both relevant and diverse. For instance, when using [Retrieval-Augmented Generation](https://www.promptingguide.ai/techniques/rag) (RAG) with LLMs, the context window is limited, so we prefer text snippets that are relevant but not redundant. The [LangChain](https://www.langchain.com/) framework supports RAG and offers the [Maximal Marginal Relevance](https://aclanthology.org/X98-1025/) (MMR) technique for this purpose.

A document has **high marginal relevance** if it is both *relevant* to the query and has *minimal similarity to previously selected documents*. **MMR** is defined as:

$$
\mathrm{MMR} \overset{\text{def}}{=} \operatorname*{arg\,max}_{D_i \in R \setminus S}
\left[ \lambda \cdot \mathrm{Sim}_1(D_i, Q) - (1 - \lambda) \cdot \max_{D_j \in S} \mathrm{Sim}_2(D_i, D_j) \right]
$$

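where:
- $C$ - document collection
- $Q$ - query
- $R$ - ranked list of documents retrieved by the IR system ($R \subseteq C$)
- $S$ - subset of documents in $R$ already provided to the user ($S \subseteq R$)
- $R \setminus S$ - subset of documents in $R$ not yet offered to the user
- $\mathrm{Sim}_1$ - similarity between a document and the query (relevance)
- $\mathrm{Sim}_2$ - similarity between two documents (redundancy); it may be the same metric as $\mathrm{Sim}_1$
- $\lambda$ - hyperparameter that balances relevance against diversity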
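To make the formula concrete, here is one selection step computed on made-up numbers (all values hypothetical). With $\lambda = 0.5$ and document A already in $S$, the highly relevant near-duplicate B loses to the slightly less relevant but different document C.

```python
lam = 0.5  # lambda; "lambda" itself is a reserved word in Python
sim_to_query = {"A": 0.90, "B": 0.88, "C": 0.75}    # Sim_1(D_i, Q)
sim_between = {("B", "A"): 0.95, ("C", "A"): 0.30}  # Sim_2(D_i, D_j)
selected = ["A"]                                    # S: already chosen
candidates = ["B", "C"]                             # R \ S

scores = {}
for d in candidates:
    redundancy = max(sim_between[(d, s)] for s in selected)
    scores[d] = lam * sim_to_query[d] - (1 - lam) * redundancy

print({d: round(s, 3) for d, s in scores.items()})  # {'B': -0.035, 'C': 0.225}
print(max(scores, key=scores.get))                  # C
```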

Let's implement this technique. First, we select the most relevant document. Then we repeatedly select the document with the highest MMR score, that is, the best trade-off between relevance to the query and dissimilarity to the documents already selected, until we have the requested number of documents.

We'll reuse the previously computed similarity matrix, this time also reading the similarities between pairs of documents. Here $R$ is simply the whole list of article titles. The document-document metric ($\mathrm{Sim}_2$) may differ from the query-document metric ($\mathrm{Sim}_1$), but for simplicity we use cosine similarity for both.

```python
def maximal_marginal_relevance(similarities, num_to_select, lambda_param):
    # similarities: (1 + n) x (1 + n) matrix; row/column 0 is the query and
    # document i is at index i + 1. Returns selected document indices in order.
    if similarities.shape[0] <= 1 or num_to_select <= 0:
        return []

    # Start with the document most similar to the query.
    most_similar = int(np.argmax(similarities[0, 1:]))

    selected = [most_similar]
    candidates = set(range(len(similarities) - 1))
    candidates.remove(most_similar)

    while len(selected) < num_to_select:
        if not candidates:
            break

        # MMR score: relevance to the query minus redundancy with the selected set.
        mmr_scores = {}
        for i in candidates:
            mmr_scores[i] = (lambda_param * similarities[i+1, 0] -
                (1 - lambda_param) * max([similarities[i+1, j+1] for j in selected]))

        next_best = max(mmr_scores, key=mmr_scores.get)
        selected.append(next_best)
        candidates.remove(next_best)
    return selected
```
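The loop above recomputes every candidate's maximum similarity to the selected set on each round. A NumPy sketch can instead keep a running maximum and update it with one vector operation per pick. It also accepts a hypothetical `pool_size` argument that restricts candidates to the `pool_size` most relevant documents, playing the role of the ranked list $R$ in the formula (the function above uses the whole collection as $R$). `mmr_vectorized` and `pool_size` are illustrative names, not part of any library.

```python
import numpy as np

def mmr_vectorized(similarities, num_to_select, lambda_param, pool_size=None):
    """Greedy MMR over a (1 + n) x (1 + n) matrix whose row/column 0 is the query.

    pool_size (hypothetical): if given, only the pool_size most relevant
    documents are candidates, acting as R in the MMR formula.
    """
    similarities = np.asarray(similarities, dtype=float)
    if similarities.ndim != 2 or similarities.shape[0] <= 1 or num_to_select <= 0:
        return []
    relevance = similarities[0, 1:]  # Sim_1(D_i, Q) for every document
    doc_sims = similarities[1:, 1:]  # Sim_2(D_i, D_j) between documents
    n = relevance.shape[0]

    # Candidate set R: all documents, or only the top pool_size by relevance.
    order = np.argsort(-relevance, kind="stable")
    pool = order if pool_size is None else order[:pool_size]
    candidate = np.zeros(n, dtype=bool)
    candidate[pool] = True

    # Running max similarity of each document to the selected set.
    max_sim_to_selected = np.full(n, -np.inf)
    selected = []
    while len(selected) < num_to_select and candidate.any():
        if selected:
            scores = lambda_param * relevance - (1 - lambda_param) * max_sim_to_selected
        else:
            scores = relevance.copy()  # first pick: the most relevant document
        scores[~candidate] = -np.inf   # never pick non-candidates or repeats
        best = int(np.argmax(scores))
        selected.append(best)
        candidate[best] = False
        # Fold the newly selected document into the running maximum.
        max_sim_to_selected = np.maximum(max_sim_to_selected, doc_sims[:, best])
    return selected

# Made-up symmetric matrix: row/column 0 is the query, then three documents.
sims4 = np.array([[1.0, 0.8, 0.6, 0.4],
                  [0.8, 1.0, 0.5, 0.2],
                  [0.6, 0.5, 1.0, 0.1],
                  [0.4, 0.2, 0.1, 1.0]])
print(mmr_vectorized(sims4, 3, 0.5))               # [0, 2, 1]
print(mmr_vectorized(sims4, 3, 0.9))               # [0, 1, 2]
print(mmr_vectorized(sims4, 3, 0.5, pool_size=2))  # [0, 1]
```

Apart from how exact score ties are broken, this should pick the same documents as `maximal_marginal_relevance` when `pool_size` is `None`.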

```python
import unittest

class TestMaximalMarginalRelevance(unittest.TestCase):
    def test_basic_case(self):
        similarities = np.array([
            [1, 0.8, 0.6],
            [0.8, 1, 0.5],
            [0.6, 0.5, 1]
        ])
        result = maximal_marginal_relevance(similarities, 2, 0.5)
        self.assertEqual(result, [0, 1])

    def test_single_selection(self):
        similarities = np.array([
            [1, 0.8, 0.9],
            [0.8, 1, 0.7],
            [0.9, 0.7, 1]
        ])
        result = maximal_marginal_relevance(similarities, 1, 0.5)
        self.assertEqual(result, [1])

    def test_all_selection(self):
        # Similarity matrices are symmetric: entry [i, j] equals entry [j, i].
        similarities = np.array([
            [1, 0.8, 0.6, 0.4],
            [0.8, 1, 0.5, 0.2],
            [0.6, 0.5, 1, 0.1],
            [0.4, 0.2, 0.1, 1]
        ])
        result = maximal_marginal_relevance(similarities, 3, 0.5)
        self.assertEqual(result, [0, 2, 1])

    def test_lambda_param(self):
        similarities = np.array([
            [1, 0.8, 0.6, 0.4],
            [0.8, 1, 0.5, 0.2],
            [0.6, 0.5, 1, 0.1],
            [0.4, 0.2, 0.1, 1]
        ])
        result = maximal_marginal_relevance(similarities, 3, 0.9)
        self.assertEqual(result, [0, 1, 2])

    def test_empty_selection(self):
        similarities = np.array([
            [1, 0.8, 0.6],
            [0.8, 1, 0.5],
            [0.6, 0.5, 1]
        ])
        result = maximal_marginal_relevance(similarities, 0, 0.5)
        self.assertEqual(result, [])

    def test_no_similarities(self):
        similarities = np.empty(0)
        result = maximal_marginal_relevance(similarities, 2, 0.5)
        self.assertEqual(result, [])

unittest.main(argv=[''], exit=False)
```

```
......
----------------------------------------------------------------------
Ran 6 tests in 0.003s

OK
```



Let's select 7 news articles using MMR with $\lambda = 0.5$:

```python
article_mmr_order = maximal_marginal_relevance(article_similarities,
                                               num_to_select=7,
                                               lambda_param=0.5)
print_selected(article_mmr_order, article_similarities, article_titles)
```

```
0.7103888392448425	Family-Friendly Activities in London
0.685788631439209	The Evolution of Theatre in London
0.6981644034385681	Best Photo Spots in London
0.6072949767112732	Weather Patterns Affecting London This Summer
0.5159177780151367	Python Tips for Mastering Data Science
0.6396927237510681	The Best Parks and Green Spaces in London
0.7076156735420227	Unique Shopping Experiences in London
```


The new selection contains no near-duplicates. However, one article is barely related to the query: "Python Tips for Mastering Data Science", with a similarity of about 0.516, well below the other selections. This can be mitigated by choosing a better value of $\lambda$.

Referring to the MMR formula again:

- $\lambda = 1$: incrementally produces the standard relevance-ranked list
- $\lambda = 0$: produces a maximal-diversity ranking among the documents in $R$
- $\lambda \in (0, 1)$: optimizes a linear combination of both criteria
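The two extremes can be checked on a small made-up matrix (hypothetical values; documents 0 and 1 are near-duplicates, document 2 is less relevant but different). The compact `mmr` below is an illustrative re-statement of the same greedy rule as `maximal_marginal_relevance`, so the snippet runs on its own; note that the first pick is always the most relevant document, whatever $\lambda$ is.

```python
import numpy as np

def mmr(similarities, k, lam):
    """Greedy MMR (row/column 0 = query); assumes at least one document and k >= 1."""
    n = similarities.shape[0] - 1
    selected = [int(np.argmax(similarities[0, 1:]))]

    def score(i):
        relevance = similarities[0, i + 1]
        redundancy = max(similarities[i + 1, j + 1] for j in selected)
        return lam * relevance - (1 - lam) * redundancy

    while len(selected) < min(k, n):
        rest = [i for i in range(n) if i not in selected]
        selected.append(max(rest, key=score))
    return selected

sims = np.array([[1.00, 0.90, 0.85, 0.60],
                 [0.90, 1.00, 0.95, 0.20],
                 [0.85, 0.95, 1.00, 0.25],
                 [0.60, 0.20, 0.25, 1.00]])
for lam in (1.0, 0.0):
    print(lam, mmr(sims, 3, lam))
# 1.0 [0, 1, 2]  -> pure relevance order, near-duplicate 1 comes second
# 0.0 [0, 2, 1]  -> second pick is the document least similar to document 0
```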

Setting a higher value, $\lambda = 0.7$, gives:

```python
article_mmr_order_07 = maximal_marginal_relevance(article_similarities,
                                                  num_to_select=7,
                                                  lambda_param=0.7)
print_selected(article_mmr_order_07, article_similarities, article_titles)
```

```
0.7103888392448425	Family-Friendly Activities in London
0.685788631439209	The Evolution of Theatre in London
0.7099422812461853	The Best Photography Exhibitions in London
0.7076156735420227	Unique Shopping Experiences in London
0.6879022121429443	Exploring London's Cultural Scene
0.6938719153404236	Top Photography Spots in London
0.6432080268859863	Day Trips from London: Exploring the Countryside
```


Now all articles are related to `London`, and there are no duplicates. Increasing $\lambda$ further to 0.8 gives:

```python
article_mmr_order_08 = maximal_marginal_relevance(article_similarities,
                                                  num_to_select=7,
                                                  lambda_param=0.8)
print_selected(article_mmr_order_08, article_similarities, article_titles)
```

```
0.7103888392448425	Family-Friendly Activities in London
0.7099422812461853	The Best Photography Exhibitions in London
0.7076156735420227	Unique Shopping Experiences in London
0.685788631439209	The Evolution of Theatre in London
0.6879022121429443	Exploring London's Cultural Scene
0.6938719153404236	Top Photography Spots in London
0.6981644034385681	Best Photo Spots in London
```


Again, all articles are related to `London`, but a near-duplicate pair is back: "Top Photography Spots in London" and "Best Photo Spots in London". For this example, $\lambda = 0.7$ strikes the best balance between relevance and diversity.