# Customizing codebase embeddings with a projection matrix

## Goal

The outcome of this exercise is to learn a projection matrix that tailors embeddings for a codebase retrieval use case, and then measure the improvement in retrieval quality.

The notebook is mostly filled out, but has a series of small gaps that you will need to fill in (everywhere you see a "TODO" comment):
- Define the similarity functions (both basic and with projection matrix)
- Define a suitable loss function
- Construct examples for training from the pre-existing dataset
- Complete the training loop code
- Finish the retrieval function logic
- Evaluate the improvement in retrieval quality

## Background

A basic retrieval-augmented generation (RAG) system typically represents the documents to be searched as embeddings. The user input is converted to an embedding as well, and the dot product between the input embedding and each document embedding scores how relevant that document is to the input.

Many embedding models are "symmetric", which means they treat user input text and documents (e.g. code snippets) in the same way. It might be preferable to compute the embedding differently ("asymmetrically") for the user input, because it is a fundamentally different type of text.

One way of doing this is to use the same embedding model, and then apply a matrix multiplication to the embedding of the user input. What we'll try to do here is find such a matrix that can improve retrieval quality.
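A minimal sketch of this idea in plain Python follows. It uses hypothetical 3-dimensional vectors in place of real embeddings. Only the query embedding is multiplied by a matrix, and the document embedding is used as-is. With the identity matrix, the score reduces to the ordinary dot product. A different matrix (here a made-up one) changes which query dimensions count toward the match.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]

def dot(u: Vector, v: Vector) -> float:
    """Plain dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))

def mat_vec(m: Matrix, v: Vector) -> Vector:
    """Multiply matrix m by column vector v."""
    return [dot(row, v) for row in m]

def asymmetric_score(query: Vector, doc: Vector, theta: Matrix) -> float:
    """Transform only the query embedding, then take the dot product with the document."""
    return dot(doc, mat_vec(theta, query))

# Hypothetical toy embeddings (real ones would come from the embedding model)
query = [1.0, 0.0, 1.0]
doc = [0.5, 2.0, 0.5]

identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
# Hypothetical matrix that also routes the first query dimension into the second
theta = [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

print(asymmetric_score(query, doc, identity))  # 1.0, same as dot(query, doc)
print(asymmetric_score(query, doc, theta))     # 3.0
```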

## Environment

We recommend using a virtual environment to install the necessary packages.

```bash
python3.11 -m venv env
source env/bin/activate
```

### Install packages

```bash
pip install -r requirements.txt
```

## Setup

Here we generate a sample embedding with `sentence_transformers`.

In [1]:
import numpy as np
import torch
import torch.nn as nn
from sentence_transformers import SentenceTransformer

# Load a pre-trained model (this will be slow the first time)
model = SentenceTransformer("all-MiniLM-L12-v2")

def embed(text):
    embedding = model.encode([text])[0]
    return torch.tensor(embedding, dtype=torch.float32)

embedding = embed("Hello world")
dim = len(embedding)

print(f"Embedding dimension: {dim}")
print(f"Embedding: [{embedding[0]}, {embedding[1]}, ..., {embedding[-1]}]")




Embedding dimension: 384
Embedding: [-0.07597316801548004, -0.0052619753405451775, ..., 0.03495463356375694]


## Similarity

First, we define similarity as the dot product between two embeddings. For example, the similarity between a user input $x_i$ and a code snippet $x_c$ is

$$h(x_i, x_c) = e(x_i) \cdot e(x_c)$$

Fill out the function below:

In [2]:
# Define the similarity function using torch
def similarity(x_i, x_c):
    # Embed both texts and take the dot product of the two embeddings
    a = embed(x_i)
    b = embed(x_c)
    return a @ b

# Calculate the similarity between two strings
x_i = "Where in the codebase do we do auth?"
x_c_1 = "```python\n# Authentication\ndef authenticate(username, password):\n    # Code to authenticate the user\n```"
x_c_2 = "function sum(a, b) {\n    return a + b;\n}"

similarity1 = similarity(x_i, x_c_1)
similarity2 = similarity(x_i, x_c_2)
print(f"Similarity 1: {similarity1}")
print(f"Similarity 2: {similarity2}")

Similarity 1: 0.4509845972061157
Similarity 2: 0.03275974839925766
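Whether this dot product behaves like cosine similarity depends on the embedding lengths. Since $a \cdot b = \|a\|\,\|b\|\cos\phi$, longer vectors score higher regardless of angle unless the embeddings are unit-normalized. The stdlib sketch below compares the two on arbitrary example vectors. If unit-length outputs are needed, `SentenceTransformer.encode` also accepts `normalize_embeddings=True`.

```python
import math
from typing import List

def dot(u: List[float], v: List[float]) -> float:
    return sum(a * b for a, b in zip(u, v))

def cosine(u: List[float], v: List[float]) -> float:
    # Dot product divided by the product of the vector lengths
    return dot(u, v) / (math.sqrt(dot(u, u)) * math.sqrt(dot(v, v)))

def normalize(v: List[float]) -> List[float]:
    norm = math.sqrt(dot(v, v))
    return [x / norm for x in v]

u = [3.0, 4.0]
v = [6.0, 8.0]  # same direction as u, twice as long

print(dot(u, v))                        # 50.0, grows with vector length
print(cosine(u, v))                     # 1.0, identical direction
print(dot(normalize(u), normalize(v)))  # ~1.0 once both vectors have unit length
```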


## Similarity with projection matrix

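Next, we calculate similarity using a projection matrix $\theta$ applied to the user-input embedding:

$$h_\theta(x_i, x_c) = e(x_c)^\top \theta\, e(x_i)$$

The code below stores the matrix as `P` and computes `(a @ P) @ b`, i.e. $e(x_i)^\top P\, e(x_c)$. This is the same score with $P = \theta^\top$, and because the matrix is learned, either convention works.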

Fill in the function below:

In [3]:
def similarity_with_projection(x_i, x_c, P):
    # Project the user-input embedding with P, then take the dot product with the snippet embedding
    a = embed(x_i)
    b = embed(x_c)
    return (a @ P) @ b

# Generate a dim by dim random matrix
P_random = torch.randn(dim, dim, dtype=torch.float32)
print(P_random)

# Calculate the similarity with the random projection matrix
similarity_with_projection1 = similarity_with_projection(x_i, x_c_1, P_random)
similarity_with_projection2 = similarity_with_projection(x_i, x_c_2, P_random)
print(f"Similarity with projection 1: {similarity_with_projection1}")
print(f"Similarity with projection 2: {similarity_with_projection2}")

tensor([[ 0.7035, -0.0096,  0.5632,  ...,  0.0893, -1.2032,  1.9034],
        [ 0.5609, -0.4919,  1.3014,  ...,  0.3438,  0.5613,  1.3412],
        [ 1.1023, -1.4101, -0.6779,  ...,  0.6688,  0.3065, -0.0364],
        ...,
        [-1.9725, -0.3642,  1.3972,  ..., -0.6769, -0.5275, -0.9344],
        [-0.0266,  0.1232,  0.1249,  ...,  1.1176,  0.9145, -0.6113],
        [ 0.5806,  1.1663,  0.1401,  ...,  0.1148,  0.1335, -1.9762]])
Similarity with projection 1: 1.1331226825714111
Similarity with projection 2: 1.5683616399765015
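The projected score is a bilinear form, and transposing the matrix swaps which side it acts on: $e(x_c)^\top \theta\, e(x_i) = e(x_i)^\top \theta^\top e(x_c)$. The stdlib check below uses small hypothetical vectors and a made-up matrix. It confirms that the row-vector form `(a @ P) @ b` gives the same number as the equation when `P` is the transpose of $\theta$.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]

def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]

def bilinear(left: Vector, m: Matrix, right: Vector) -> float:
    """Compute left^T m right for vectors given as plain lists."""
    return sum(left[r] * m[r][c] * right[c]
               for r in range(len(left)) for c in range(len(right)))

# Hypothetical 2-d embeddings and matrix
e_i = [1.0, 2.0]   # user input embedding
e_c = [3.0, -1.0]  # code snippet embedding
theta = [[0.5, 1.0],
         [2.0, 0.0]]

# The equation's form: e(x_c)^T theta e(x_i)
equation_form = bilinear(e_c, theta, e_i)
# The code's form: e(x_i)^T P e(x_c), with P = theta^T
code_form = bilinear(e_i, transpose(theta), e_c)

print(equation_form, code_form)  # 5.5 5.5
```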


## Load dataset

To train and test a matrix that performs better than the random one above, we use a pre-existing dataset of (question, relevant code snippets) pairs that were generated by a language model.

In [4]:
# Load the dataset from XML file (dataset.xml)

import xml.etree.ElementTree as ET
from dataclasses import dataclass
from typing import List

@dataclass 
class Example:
    user_input: str
    snippets: List[str]

class DatasetParser:
    def __init__(self, xml_file: str):
        self.tree = ET.parse(xml_file)
        self.root = self.tree.getroot()

    def parse(self) -> List[Example]:
        examples = []
        
        for example in self.root.findall('example'):
            user_input = example.find('user_input').text
            snippets_list = []
            
            for snippet in example.find('snippets').findall('snippet'):
                # Extract code and filename from the snippet text
                snippet_text = snippet.text.strip()
                
                # Parse the filename from the code block header (currently not stored)
                first_line = snippet_text.split('\n')[0]
                filename = first_line.split(' ')[1] if len(first_line.split(' ')) > 1 else None
                
                # Remove the code block markers and get just the code
                code_lines = snippet_text.split('\n')[1:-1]
                code = '\n'.join(code_lines)
                
                snippets_list.append(code)
                
            examples.append(Example(
                user_input=user_input,
                snippets=snippets_list
            ))
            
        return examples


parser = DatasetParser('dataset.xml')
dataset = parser.parse()
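The parser implies a layout for `dataset.xml`. A root element contains `<example>` children, and each of them holds a `<user_input>` and a `<snippets>` list whose `<snippet>` bodies are fenced code blocks with a language and filename on the opening line. The sketch below parses a hypothetical one-snippet example from a string with the same `find`/`findall` calls. The root tag name `dataset`, the question, the filename, and the code are all made up.

```python
import xml.etree.ElementTree as ET

FENCE = "`" * 3  # built programmatically so this example's own code fence stays intact

# Hypothetical example in the layout the parser expects; the root tag name is an assumption
xml_text = f"""<dataset>
  <example>
    <user_input>How do we hash passwords?</user_input>
    <snippets>
      <snippet>
{FENCE}python auth.py
def hash_password(pw):
    return bcrypt.hash(pw)
{FENCE}
      </snippet>
    </snippets>
  </example>
</dataset>"""

root = ET.fromstring(xml_text)
for example in root.findall("example"):
    user_input = example.find("user_input").text
    for snippet in example.find("snippets").findall("snippet"):
        lines = snippet.text.strip().split("\n")
        filename = lines[0].split(" ")[1]  # header line is the fence + "python auth.py"
        code = "\n".join(lines[1:-1])      # drop the opening and closing fence lines
        print(user_input, "|", filename)
        print(code)
```

Some snippets in the dataset contain `<` (for example the TypeScript `Promise<User>` in the retrieval output further down). For the file to be well-formed XML, such characters must be escaped (`&lt;`) or the snippet body must be wrapped in `<![CDATA[ ... ]]>`. `ElementTree` returns either form as plain text.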


## Construct examples

Convert the dataset into a set of examples that can be used to train the projection matrix. These should include both examples of input/snippet pairs where the snippet is relevant, and pairs where the snippet is not relevant.

In [5]:
# Next, you should generate a list of positive and negative pairs from the dataset
# These will be used to train the matrix

# TODO: Create example pairs from the dataset

# list of tuples (user input, code snippet, 1 if snippet is relevant to user input else 0)
example_pairs = []
threshold = 0.7


for example in dataset:
    for snippet in example.snippets:
        similarity_score = similarity(example.user_input, snippet).item()
        if similarity_score > threshold:
            example_pairs.append((example.user_input, snippet, 1))
        else:
            example_pairs.append((example.user_input, snippet, 0))
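The cell above marks a pair as relevant only when the untrained similarity exceeds 0.7. As a result, the labels encode the baseline model's own judgment, even though every pair comes from an example's own snippet list, which the dataset already marks as relevant. The alternative sketch below takes labels from the dataset structure instead: each question's own snippets are positives, and snippets sampled from other examples are negatives. It assumes that a snippet listed under another question is not relevant to this one. `build_pairs`, `num_negatives`, and the seed are illustrative choices, and `Example` is redefined so that the block runs on its own.

```python
import random
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class Example:
    user_input: str
    snippets: List[str]

def build_pairs(dataset: List[Example], num_negatives: int = 2,
                seed: int = 0) -> List[Tuple[str, str, int]]:
    """Label each question's own snippets 1 and a sample of other examples' snippets 0.

    Assumes a snippet listed under a different question is not relevant to this one.
    """
    rng = random.Random(seed)
    pairs = []
    for idx, example in enumerate(dataset):
        own = set(example.snippets)
        for snippet in example.snippets:
            pairs.append((example.user_input, snippet, 1))
        # Candidate negatives: snippets from every other example, excluding our own
        others = [s for j, other in enumerate(dataset) if j != idx
                  for s in other.snippets if s not in own]
        for snippet in rng.sample(others, min(num_negatives, len(others))):
            pairs.append((example.user_input, snippet, 0))
    return pairs

# Hypothetical toy dataset
toy = [Example("Reset password", ["def reset(): ...", "def token(): ..."]),
       Example("Process payment", ["def pay(): ..."])]
pairs = build_pairs(toy, num_negatives=1)
print(sum(pair[2] for pair in pairs), len(pairs))  # 3 positives out of 5 pairs
```

In the notebook this would replace the threshold loop, e.g. `example_pairs = build_pairs(dataset)`.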


In [6]:
# Here we split the example pairs into training and validation sets
np.random.shuffle(example_pairs)
split_index = int(0.8 * len(example_pairs))
train_pairs = example_pairs[:split_index]
val_pairs = example_pairs[split_index:]

print(f"Number of training pairs: {len(train_pairs)}")
print(f"Number of validation pairs: {len(val_pairs)}")

Number of training pairs: 16
Number of validation pairs: 4


## Define a loss function

With a model to calculate similarity, and a dataset of positive and negative examples, we're almost ready to train. The last thing we need is a loss function. Design a loss function that is suitable for this use case.

In [7]:
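def loss_func(predictions, targets):
    # Treat the raw similarity score as a logit and apply binary cross-entropy
    # against the 0/1 relevance label (the sigmoid is applied internally)
    return nn.BCEWithLogitsLoss()(predictions, targets)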
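`BCEWithLogitsLoss` treats the raw similarity score $s$ as a logit. It applies the sigmoid $\sigma(s)$ internally and computes $-[y\log\sigma(s) + (1-y)\log(1-\sigma(s))]$. PyTorch documents that fusing the two steps is more numerically stable than a separate sigmoid followed by `BCELoss`. The stdlib sketch below writes the same per-example loss in the common stable form $\max(s,0) - s\,y + \log(1 + e^{-|s|})$ and compares it with the textbook form.

```python
import math

def bce_with_logits(s: float, y: float) -> float:
    """Stable binary cross-entropy on a raw score s with target y in {0, 1}."""
    return max(s, 0.0) - s * y + math.log1p(math.exp(-abs(s)))

def bce_naive(s: float, y: float) -> float:
    """Textbook form: sigmoid first, then cross-entropy (breaks down for large |s|)."""
    p = 1.0 / (1.0 + math.exp(-s))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

for s, y in [(2.0, 1), (2.0, 0), (-1.5, 1), (0.0, 0)]:
    print(s, y, round(bce_with_logits(s, y), 6), round(bce_naive(s, y), 6))

# A large positive score on a negative pair: the stable form stays finite,
# while the naive form would hit log(0) because the sigmoid rounds to exactly 1.0
print(bce_with_logits(50.0, 0))  # 50.0
```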

## Train the projection matrix

The training loop is fully set up except for the lines that compute the prediction for an example pair and its target $y$. Those two values are then passed to the loss function.
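Consider a single pair with score $s = e(x_i)^\top P\, e(x_c)$, as computed by `similarity_with_projection`. The gradient of the logistic loss with respect to $P$ is $(\sigma(s) - y)\, e(x_i)\, e(x_c)^\top$: an outer product of the two embeddings, scaled by the prediction error. This is what `loss.backward()` accumulates into `P.grad` for each pair. Because `optimizer.zero_grad()` runs once per epoch, these per-pair gradients are summed before each optimizer step. The stdlib sketch below takes one plain gradient-descent step, whereas the notebook uses Adam. It uses hypothetical 2-d embeddings and starts from the identity matrix, so the initial score equals the plain dot product.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]

def bilinear_score(e_i: Vector, P: Matrix, e_c: Vector) -> float:
    """Score e_i^T P e_c, the same form as `(a @ P) @ b`."""
    return sum(e_i[r] * P[r][c] * e_c[c] for r in range(len(e_i)) for c in range(len(e_c)))

def sigmoid(s: float) -> float:
    return 1.0 / (1.0 + math.exp(-s))

def pair_loss(e_i: Vector, P: Matrix, e_c: Vector, y: float) -> float:
    # Binary cross-entropy on the raw score, in the numerically stable form
    s = bilinear_score(e_i, P, e_c)
    return max(s, 0.0) - s * y + math.log1p(math.exp(-abs(s)))

def gradient_step(e_i: Vector, P: Matrix, e_c: Vector, y: float, lr: float) -> Matrix:
    """Plain gradient descent: dL/dP[r][c] = (sigmoid(s) - y) * e_i[r] * e_c[c]."""
    err = sigmoid(bilinear_score(e_i, P, e_c)) - y
    return [[P[r][c] - lr * err * e_i[r] * e_c[c] for c in range(len(e_c))]
            for r in range(len(e_i))]

# Hypothetical embeddings for a relevant pair (y = 1) that the plain dot product scores 0
query_emb = [1.0, 0.0]
snippet_emb = [0.0, 1.0]
P_toy = [[1.0, 0.0], [0.0, 1.0]]  # identity start: the score equals the plain dot product

loss_before = pair_loss(query_emb, P_toy, snippet_emb, 1.0)
P_toy = gradient_step(query_emb, P_toy, snippet_emb, 1.0, lr=0.5)
loss_after = pair_loss(query_emb, P_toy, snippet_emb, 1.0)

print(P_toy)                     # [[1.0, 0.25], [0.0, 1.0]]: only one entry moves
print(loss_after < loss_before)  # True
```

Starting from the identity (e.g. `torch.eye(dim, requires_grad=True)`) rather than `torch.randn` is one design choice: training then begins from the baseline similarity instead of a random bilinear form.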

In [8]:
import torch.optim as optim

# Initialize the projection matrix P
P = torch.randn(
    dim, dim, requires_grad=True
)

# Set hyperparameters
lr = 0.001
num_epochs = 25
optimizer = optim.Adam([P], lr=lr)

for epoch in range(num_epochs):
    # Reset gradients
    optimizer.zero_grad()

    # Accumulate gradients over all training pairs (one optimizer step per epoch)
    for pair in train_pairs:
        # TODO: Get `prediction` and `y` to pass to `loss_func` 
        x_i, x_c, y = pair
        prediction = similarity_with_projection(x_i, x_c, P)
        y = torch.tensor([y], dtype=torch.float32)

        loss = loss_func(prediction.unsqueeze(0), y)
        loss.backward()
    
    # Update weights using Adam optimizer
    optimizer.step()

    # Calculate validation loss (no gradients are needed here)
    val_loss = 0
    with torch.no_grad():
        for pair in val_pairs:
            # TODO: Get `prediction` and `y` to pass to `loss_func`
            x_i, x_c, y = pair
            prediction = similarity_with_projection(x_i, x_c, P)
            y = torch.tensor([y], dtype=torch.float32)

            val_loss += loss_func(prediction.unsqueeze(0), y)

    print(f"Epoch {epoch}/{num_epochs}: validation loss: {val_loss.item() / len(val_pairs)}")


Epoch 0/25: validation loss: 0.6012709140777588
Epoch 1/25: validation loss: 0.5857683420181274
Epoch 2/25: validation loss: 0.5706733465194702
Epoch 3/25: validation loss: 0.5560002326965332
Epoch 4/25: validation loss: 0.5417613983154297
Epoch 5/25: validation loss: 0.5279673337936401
Epoch 6/25: validation loss: 0.5146280527114868
Epoch 7/25: validation loss: 0.5017523169517517
Epoch 8/25: validation loss: 0.4893467128276825
Epoch 9/25: validation loss: 0.47741591930389404
Epoch 10/25: validation loss: 0.46596160531044006
Epoch 11/25: validation loss: 0.454982191324234
Epoch 12/25: validation loss: 0.4444738030433655
Epoch 13/25: validation loss: 0.4344303011894226
Epoch 14/25: validation loss: 0.42484262585639954
Epoch 15/25: validation loss: 0.4157010316848755
Epoch 16/25: validation loss: 0.4069927930831909
Epoch 17/25: validation loss: 0.39870527386665344
Epoch 18/25: validation loss: 0.39082393050193787
Epoch 19/25: validation loss: 0.3833339810371399
Epoch 20/25: validation lo

## Retrieval strategy

We now have a potentially improved embedding model, but need to use it for retrieval. Finish the retrieval function, which will take a user input and return relevant code snippets from the full list. Note: a vector database is not necessary.

In [10]:

all_snippets = []

for example in dataset:
    for snippet in example.snippets:
        all_snippets.append(snippet)

# Use similarity search with the embeddings model to retrieve relevant snippets
def retrieve_relevant_snippets(user_input: str):
    user_embedding = embed(user_input)
    projected_user_embedding = user_embedding @ P

    similarities = []

    for snippet in all_snippets:
        snippet_embedding = embed(snippet)
        similarity = (projected_user_embedding @ snippet_embedding).item()
        similarities.append((snippet, similarity))
    return similarities
    raise NotImplementedError.add_note
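`retrieve_relevant_snippets` returns `(snippet, score)` pairs in the order of `all_snippets`, while `average_precision` below treats its input as a ranking with the best match first. The sketch below shows a ranked top-k variant over already-computed scores. `rank_snippets` and `top_k` are hypothetical names, and the scores are made up.

```python
from typing import List, Optional, Tuple

def rank_snippets(scored: List[Tuple[str, float]],
                  top_k: Optional[int] = None) -> List[Tuple[str, float]]:
    """Sort (snippet, score) pairs by score, highest first, optionally keeping the top k."""
    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
    return ranked if top_k is None else ranked[:top_k]

# Hypothetical scores in the unsorted order a retrieval function might return
scored = [("snippet_a", 0.2), ("snippet_b", 1.3), ("snippet_c", -0.4)]

print(rank_snippets(scored, top_k=2))  # [('snippet_b', 1.3), ('snippet_a', 0.2)]
```

In the notebook, wrapping the existing function, e.g. `lambda q: rank_snippets(retrieve_relevant_snippets(q))`, would hand `evaluate_retrieval_strategy` a ranked list. The MAP printed at the end of this notebook was computed on the unsorted output.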

## Evaluate the new retrieval strategy

If the validation loss decreased by the last epoch, the projected similarity function fits the validation pairs better than it did at initialization. We still need a way to evaluate the retrieval strategy as a whole.

Your last task is to design an evaluation metric suitable for codebase retrieval, which we can run over the examples in the above dataset. The result of the evaluation should be a single number that attempts to represent the quality of the retrieval strategy.
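Average precision (AP), the metric used below, rewards placing relevant snippets early. Precision is measured at the rank of each relevant hit and averaged over all relevant snippets, and mean average precision (MAP) averages AP across queries. For a query with relevant set {A, B} and ranking [A, X, B], AP is $(1/1 + 2/3)/2 \approx 0.833$. The standalone sketch below works through the same arithmetic on hypothetical rankings. `ap_of_ranking` mirrors the notebook's `average_precision` but takes plain strings.

```python
from typing import Sequence

def ap_of_ranking(relevant: Sequence[str], ranked: Sequence[str]) -> float:
    """Sum of precision@rank at each relevant hit, divided by the number of relevant items."""
    relevant_set = set(relevant)
    hits, precision_sum = 0, 0.0
    for rank, item in enumerate(ranked, start=1):
        if item in relevant_set:
            hits += 1
            precision_sum += hits / rank
    return precision_sum / len(relevant) if relevant else 0.0

ap_good = ap_of_ranking(["A", "B"], ["A", "X", "B"])       # (1/1 + 2/3) / 2
ap_bad = ap_of_ranking(["A", "B"], ["X", "Y", "A", "B"])   # (1/3 + 2/4) / 2
map_two_queries = (ap_good + ap_bad) / 2                    # MAP over these two queries
print(round(ap_good, 4), round(ap_bad, 4), round(map_two_queries, 4))  # 0.8333 0.4167 0.625
```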

In [11]:
retrieve_relevant_snippets("Reset password")

[('def initiate_password_reset(email):\n    token = generate_reset_token()\n    send_reset_email(email, token)\n    store_reset_token(email, token, expiry=24*hours)\n    return True\n\ndef validate_reset_token(token, new_password):\n    if is_token_valid(token):\n        user = get_user_by_token(token)\n        update_password(user, new_password)\n        invalidate_token(token)\n        return True\n    return False',
  -1.4833674430847168),
 ('export class UserService {\n  async register(userData: RegisterDTO): Promise<User> {\n    const existingUser = await this.userRepo.findByEmail(userData.email);\n    if (existingUser) {\n      throw new DuplicateUserError();\n    }\n    \n    const hashedPassword = await bcrypt.hash(userData.password);\n    return this.userRepo.create({\n      ...userData,\n      password: hashedPassword\n    });\n  }\n}',
  -0.5153960585594177),
 ("class PaymentProcessor {\n  async processPayment(amount: number, paymentMethod: PaymentMethod): Promise<PaymentRes

In [12]:
from typing import List, Tuple

def average_precision(relevant_items: List[str], retrieved_items: List[Tuple[str, float]]) -> float:
    # Treats `retrieved_items` as a ranking (best match first): precision is
    # taken at the rank of each relevant hit and averaged over all relevant items
    relevant_set = set(relevant_items)
    precision_sum = 0
    relevant_count = 0

    for i, (item, _) in enumerate(retrieved_items, 1):
        if item in relevant_set:
            relevant_count += 1
            precision_sum += relevant_count / i

    return precision_sum / len(relevant_items) if relevant_items else 0

def evaluate_retrieval_strategy(retrieval_strategy) -> float:
    # Mean Average Precision (MAP): average the per-query AP over the dataset
    total_ap = 0
    num_queries = 0

    for example in dataset:
        retrieved = retrieval_strategy(example.user_input)
        ap = average_precision(example.snippets, retrieved)
        total_ap += ap
        num_queries += 1

    return total_ap / num_queries if num_queries > 0 else 0

result = evaluate_retrieval_strategy(retrieve_relevant_snippets)
print(f"Mean Average Precision (MAP): {result:.4f}")

Mean Average Precision (MAP): 0.1799
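Measuring the improvement that the Goal section asks for also requires the same metric on a baseline, such as ranking with the plain dot product (equivalently, an identity projection). The stdlib sketch below compares two projection matrices by MAP on a hypothetical two-query dataset of toy 2-d embeddings. The "swap" matrix is hand-picked for illustration and was not learned. In the notebook, the analogue would be calling `evaluate_retrieval_strategy` on a baseline retrieval function built on `similarity` and on the projected one, assuming both return rankings sorted by score.

```python
from typing import Dict, List, Sequence, Tuple

Vector = List[float]
Matrix = List[List[float]]

def dot(u: Vector, v: Vector) -> float:
    return sum(x * y for x, y in zip(u, v))

def project(P: Matrix, v: Vector) -> Vector:
    # Row-vector convention v^T P, matching `user_embedding @ P` in the notebook
    return [sum(v[r] * P[r][c] for r in range(len(v))) for c in range(len(P[0]))]

def ap_for_ranking(relevant: Sequence[str], ranked: Sequence[str]) -> float:
    hits, total = 0, 0.0
    for rank, item in enumerate(ranked, start=1):
        if item in relevant:
            hits += 1
            total += hits / rank
    return total / len(relevant) if relevant else 0.0

def mean_ap(P: Matrix, queries: Dict[str, Tuple[Vector, List[str]]],
            snippets: Dict[str, Vector]) -> float:
    """MAP when each query embedding is projected by P and snippets are ranked by score."""
    scores = []
    for query_vec, relevant in queries.values():
        q = project(P, query_vec)
        ranked = sorted(snippets, key=lambda name: dot(q, snippets[name]), reverse=True)
        scores.append(ap_for_ranking(relevant, ranked))
    return sum(scores) / len(scores)

# Hypothetical toy data: each query embedding points at the wrong snippet
snippets = {"auth": [0.0, 1.0], "sum": [1.0, 0.0]}
queries = {"auth question": ([1.0, 0.0], ["auth"]),
           "sum question": ([0.0, 1.0], ["sum"])}

identity = [[1.0, 0.0], [0.0, 1.0]]  # baseline: plain dot product
swap = [[0.0, 1.0], [1.0, 0.0]]      # hand-picked projection for illustration

print(mean_ap(identity, queries, snippets), mean_ap(swap, queries, snippets))  # 0.5 1.0
```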
