## Structure


## Setup

In [None]:
# !pip install -r requirements.txt
# !pip install sentence-transformers
# !pip install mteb
# !pip install beir
# !pip install datasets
# !pip install wandb

## Matryoshka-Adaptor Implementation

### Architecture

In [None]:
import torch
import torch.nn.functional as F

# Define MatryoshkaAdaptor module - a simple MLP with skip connection
class MatryoshkaAdaptor(torch.nn.Module):
    """
    Residual two-layer MLP that is applied on top of an embedding model's output.

    The input embedding is projected to ``hidden_dim``, passed through a ReLU, projected back to
    ``input_output_dim`` and finally added to the untouched input (skip connection). The output
    therefore always has the same dimensionality as the input.
    """
    def __init__(self, input_output_dim, hidden_dim):
        """
        Builds the two linear projections and the activation.

        Args:
            input_output_dim: An integer giving the size of the incoming embedding, which is also the output size.
            hidden_dim: An integer giving the width of the hidden layer between the two projections.

        Returns:
            None
        """
        super().__init__()
        # input -> hidden projection
        self.linear1 = torch.nn.Linear(input_output_dim, hidden_dim)
        # hidden -> output projection (output size equals input size so the skip connection can be added)
        self.linear2 = torch.nn.Linear(hidden_dim, input_output_dim)
        # Non-linearity applied between the two projections
        self.activation = torch.nn.ReLU()

    def forward(self, embedding):
        """
        Runs the adaptor on a batch of embeddings.

        Args:
            embedding: A torch.Tensor of shape (batch_size, input_output_dim) holding the input embeddings.

        Returns:
            output: A torch.Tensor of shape (batch_size, input_output_dim): the MLP output plus the input itself.
        """
        # MLP branch: linear -> ReLU -> linear
        adapted_embedding = self.linear2(self.activation(self.linear1(embedding)))

        # Skip connection: the MLP only has to learn a correction to the original embedding
        return adapted_embedding + embedding

### Loss Functions

In [None]:
# Equation 1 in paper
def pairwise_similarity_loss(original_embeddings, matryoshka_embeddings):
    """
    Measures how much the cosine similarity of every pair of samples in the batch changes
    between the original embeddings and the matryoshka embeddings.

    Args:
        original_embeddings: A tensor of shape (batch_size, embedding_dim) representing the original embeddings.
        matryoshka_embeddings: A tensor of shape (batch_size, mat_embedding_dim) representing the matryoshka embeddings.

    Returns:
        loss: A scalar tensor holding the sum (not the mean) of the absolute similarity differences
            over all unordered pairs (i, j) with i < j.
    """
    def _cosine_matrix(embeddings):
        # Rows are scaled to unit L2 norm, so the Gram matrix holds cosine similarities
        unit_rows = F.normalize(embeddings, p=2, dim=1)
        return torch.matmul(unit_rows, unit_rows.T)

    # Strict upper triangle: each unordered pair once, self-similarity excluded
    num_samples = original_embeddings.size(0)
    rows, cols = torch.triu_indices(num_samples, num_samples, offset=1)

    # Element-wise |original - matryoshka| similarity gap for every pair of samples
    similarity_gap = torch.abs(
        _cosine_matrix(original_embeddings) - _cosine_matrix(matryoshka_embeddings)
    )

    return similarity_gap[rows, cols].sum()

# Equation 2 in paper
def topk_similarity_loss(original_embeddings, matryoshka_embeddings, k=5):
    """
    Computes the top-k similarity loss between original embeddings and matryoshka embeddings.

    For every sample, its k nearest neighbours (by cosine similarity) are determined in the
    ORIGINAL embedding space. The loss is the summed absolute difference between the original
    and the matryoshka cosine similarity for exactly those (sample, neighbour) pairs, so the
    adaptor is penalised for changing the similarity to the neighbours that matter, rather than
    for a mismatch between two independently sorted lists of values.

    Args:
        original_embeddings: A tensor of shape (batch_size, embedding_dim) representing the original embeddings.
        matryoshka_embeddings: A tensor of shape (batch_size, mat_embedding_dim) representing the matryoshka embeddings.
        k: The number of nearest neighbours to consider per sample (default is 5). Values larger than
            batch_size - 1 are clamped, since a sample has at most batch_size - 1 other samples.

    Returns:
        loss: A scalar tensor representing the top-k similarity loss (a sum, not a mean). It is zero
            when no neighbours are available (batch of a single sample, or k <= 0).
    """
    batch_size = original_embeddings.size(0)

    # A sample cannot have more neighbours than there are other samples in the batch
    effective_k = min(k, batch_size - 1)
    if effective_k <= 0:
        # Nothing to compare; return a zero that is still connected to the autograd graph
        return (matryoshka_embeddings * 0).sum()

    # Normalize the embeddings along the embedding dimension to get the cosine similarity
    normalized_original_embeddings = F.normalize(original_embeddings, p=2, dim=1)
    normalized_matryoshka_embeddings = F.normalize(matryoshka_embeddings, p=2, dim=1)

    # Compute the cosine similarity matrices
    original_similarity_matrix = torch.matmul(normalized_original_embeddings, normalized_original_embeddings.T)
    matryoshka_similarity_matrix = torch.matmul(normalized_matryoshka_embeddings, normalized_matryoshka_embeddings.T)

    # Exclude self-similarity when ranking neighbours (out-of-place, so the matrices stay intact)
    self_mask = torch.eye(batch_size, dtype=torch.bool, device=original_similarity_matrix.device)
    ranking_scores = original_similarity_matrix.masked_fill(self_mask, -float('inf'))

    # Neighbour identities come from the original space only
    _, neighbour_indices = torch.topk(ranking_scores, effective_k, dim=1)

    # Read the similarity of the SAME (sample, neighbour) pairs in both spaces
    original_topk_values = torch.gather(original_similarity_matrix, 1, neighbour_indices)
    matryoshka_topk_values = torch.gather(matryoshka_similarity_matrix, 1, neighbour_indices)

    # Sum up all the absolute differences to produce the final loss
    return torch.sum(torch.abs(original_topk_values - matryoshka_topk_values))


# Equation 3 in paper
def regularization_loss(original_embeddings, matryoshka_embeddings, alpha=1.0):
    """
    Penalises matryoshka embeddings for drifting away from the original embeddings.

    Args:
        original_embeddings: A tensor of shape (batch_size, embedding_dim) representing the original embeddings.
        matryoshka_embeddings: A tensor of shape (batch_size, embedding_dim) representing the matryoshka embeddings.
        alpha: A coefficient the mean distance is multiplied with.

    Returns:
        loss: A scalar tensor: alpha times the batch mean of the per-sample L2 distance
            between the original and the matryoshka embedding.
    """
    # Euclidean distance between the two embeddings of each sample
    per_sample_distance = (original_embeddings - matryoshka_embeddings).norm(p=2, dim=1)

    # Average over the batch and apply the coefficient
    return alpha * per_sample_distance.mean()


# Equation 4 in paper
def unsupervised_objective_fn_loss(original_embeddings, matryoshka_embeddings, 
                                   k=5, alpha=1.0, beta=1.0):
    """
    Computes the overall unsupervised objective function loss as a combination of top-k similarity loss,
    alpha-scaled pairwise similarity loss, and beta-scaled regularization loss:

        total = topk + alpha * pairwise + beta * regularization
    
    Args:
        original_embeddings: A tensor of shape (batch_size, embedding_dim) representing the original embeddings.
        matryoshka_embeddings: A tensor of shape (batch_size, mat_embedding_dim) representing the matryoshka embeddings.
        k: The number of top similar embeddings to consider for the top-k similarity loss.
        alpha: A scaling factor for the pairwise similarity loss.
        beta: A scaling factor for the regularization loss.
        
    Returns:
        total_loss: A scalar tensor representing the combined unsupervised objective function loss.
    """
    # Compute the individual loss components
    topk_loss = topk_similarity_loss(original_embeddings, matryoshka_embeddings, k)
    pairwise_loss = pairwise_similarity_loss(original_embeddings, matryoshka_embeddings)
    # regularization_loss has its own coefficient; leave it at its default of 1.0 here so that
    # beta is applied exactly once below (passing beta in as well would scale the term by beta**2)
    reg_loss = regularization_loss(original_embeddings, matryoshka_embeddings)
    
    # Combine the losses with the given scaling factors
    total_loss = topk_loss + alpha * pairwise_loss + beta * reg_loss
    
    return total_loss


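# Equation 5 in paper
# TODO: not implemented yet. `supervised_objective_fn_loss`, which the supervised training cell
# further down passes to train(), still has to be defined here.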


## Training of Adaptor

### Unsupervised Implementation

#### Prepare Datasets + Dataloaders

We will use BEIR's NFCorpus, and train on the corpus only in an unsupervised manner, as detailed in the paper.

In [None]:
from datasets import load_dataset
from torch.utils.data import DataLoader, random_split
import torch

# Download the "corpus" configuration of BEIR's NFCorpus from the Hugging Face hub
corpus_ds = load_dataset("BeIR/nfcorpus", "corpus")

# Keep only the 'text' column of the 'corpus' split; the unsupervised setup needs no queries or labels
dataset = corpus_ds['corpus']['text']

# 70% of the documents go to training, the remainder to testing
train_size = int(0.7 * len(dataset))
test_size = len(dataset) - train_size

# Randomly assign documents to the two subsets.
# NOTE: no generator is passed to random_split, so the split is not pinned by this cell.
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

# Number of documents per batch
batch_size = 32  # Adjust this as needed

# Training batches are reshuffled on every pass; the test loader keeps a fixed order
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)


#### Prepare Embedding Model and Matryoshka-Adaptor

In [None]:
from sentence_transformers import SentenceTransformer

# Embedding model whose outputs are fed to the adaptor (train() only optimizes the adaptor's parameters)
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Matryoshka-Adaptor: input/output size follows the embedding model, hidden layer uses the same width
input_output_dim = model.get_sentence_embedding_dimension() # Embedding dimension for model (d in paper)
hidden_dim = input_output_dim # Let hidden layer dimension equal the embedding model dimension
mat_adaptor = MatryoshkaAdaptor(input_output_dim, hidden_dim)

#### Train Matryoshka-Adaptor

In [None]:
import torch
from torch.optim import Adam
import wandb

def train(model, mat_adaptor, train_loader, loss_fn, kwargs):
    """
    Trains the MatryoshkaAdaptor module using the provided training data.

    Only the adaptor's parameters are handed to the optimizer; the embedding model is used
    purely to produce the input embeddings. Per-epoch average loss is logged to
    Weights & Biases and printed. The W&B run is always closed, even if training raises.

    Args:
        model: A SentenceTransformer model to generate embeddings.
        mat_adaptor: A MatryoshkaAdaptor module to adapt the embeddings. It is moved in place to
            the device of ``model`` so that it can consume the embeddings ``model.encode`` returns.
        train_loader: A DataLoader object for the training dataset. Each batch is passed to
            ``model.encode`` unchanged.
        loss_fn: A loss function called as
            ``loss_fn(original_embeddings, matryoshka_embeddings, k=..., alpha=..., beta=...)``.
        kwargs: A dictionary containing hyperparameters for training. Recognised keys (with
            defaults): 'epochs' (5), 'lr' (1e-3), 'k' (5), 'alpha' (1.0), 'beta' (1.0).
            An 'm' entry (Matryoshka embedding dimension) is recorded in the W&B config but is
            not used by this loop.

    Returns:
        None
    """
    
    # Initialize Weights & Biases
    wandb.init(project="matryoshka-training", config=kwargs)

    try:
        config = wandb.config

        # Unpack the hyperparameters
        epochs = config.get('epochs', 5)
        lr = config.get('lr', 1e-3)
        k = config.get('k', 5) # Top-k similarity loss
        alpha = config.get('alpha', 1.0) # Pairwise similarity loss scaling factor (alpha in paper)
        beta = config.get('beta', 1.0)  # Regularization loss scaling factor (beta in paper)

        # The embeddings come back on the encoder's device; the adaptor has to live there too.
        # Done before the optimizer is created so it is built for the final parameter placement.
        mat_adaptor.to(model.device)

        # Define an optimizer for the MatryoshkaAdaptor parameters
        optimizer = Adam(mat_adaptor.parameters(), lr=lr)

        # Set MatryoshkaAdaptor to training mode
        mat_adaptor.train()

        for epoch in range(epochs):
            total_loss = 0.0
            num_batches = 0
            for batch in train_loader:
                # Embed the batch with the embedding model
                ori_embeddings = model.encode(batch, convert_to_tensor=True)  # model batched embeddings

                # Forward pass embedding through the MatryoshkaAdaptor
                mat_embeddings = mat_adaptor(ori_embeddings)

                # Compute loss
                loss = loss_fn(ori_embeddings, mat_embeddings, k=k, alpha=alpha, beta=beta)

                # Backpropagation
                optimizer.zero_grad()  # Clear previous gradients
                loss.backward()        # Compute gradients
                optimizer.step()       # Update weights

                total_loss += loss.item()
                num_batches += 1

            # Average over the batches actually processed; an empty loader yields NaN instead of
            # a ZeroDivisionError
            avg_loss = total_loss / num_batches if num_batches else float('nan')

            # Log the average loss to W&B
            wandb.log({"epoch": epoch + 1, "loss": avg_loss})

            # Print average loss for the epoch
            print(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}")
    finally:
        # Finish the W&B run, also when training was interrupted by an exception
        wandb.finish()


In [None]:
# Hyperparameters for the unsupervised run; the dict is passed to train() as its `kwargs` argument
hyperparams = {
    'epochs': 5,
    'lr': 1e-3,
    'k': 5,  # Top-k similarity loss
    'm': 128,  # Matryoshka embedding dimension
    'alpha': 1.0,  # Pairwise similarity loss scaling factor (alpha in paper)
    'beta': 1.0  # Regularization loss scaling factor (beta in paper)
}

# Train the adaptor on the corpus-only dataloader with the unsupervised objective
train(model, mat_adaptor, train_dataloader, unsupervised_objective_fn_loss, hyperparams)

### Supervised Implementation

#### Prepare Datasets + Dataloaders

As before, we will use BEIR's NFCorpus, but this time we train on the query-document pairs in a supervised manner, as detailed in the paper.
We need to download the BEIR dataset manually, since the qrels (relevance judgments) cannot be accessed through the Hugging Face dataset.

In [4]:
from beir import util, LoggingHandler
from beir.datasets.data_loader import GenericDataLoader
import torch
from torch.utils.data import Dataset, DataLoader
import logging

# Route log records through BEIR's LoggingHandler, timestamped, at INFO level
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO, handlers=[LoggingHandler()])

# Define the dataset name and the path to store it.
# NOTE: this rebinds the notebook-level name `dataset`, which the unsupervised section used for the list of corpus texts.
dataset = "nfcorpus"
url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
data_path = util.download_and_unzip(url, "datasets")

# Load corpus, queries and qrels for the train, dev and test splits
# (the recorded output of this cell reports 3633 documents for each split)
train_corpus, train_queries, train_qrels = GenericDataLoader(data_path).load(split="train")
dev_corpus, dev_queries, dev_qrels = GenericDataLoader(data_path).load(split="dev")
test_corpus, test_queries, test_qrels = GenericDataLoader(data_path).load(split="test")

class BEIRDataset(Dataset):
    """
    Map-style dataset with one item per query: the query text together with every document
    that has a relevance judgment (qrel) for that query.
    """
    def __init__(self, queries, corpus, qrels):
        """
        Args:
            queries: A dict mapping query id -> query text.
            corpus: A dict mapping document id -> document, either a dict with 'title'/'text'
                entries or a plain string.
            qrels: A dict mapping query id -> {document id: relevance score}.
        """
        self.queries = queries
        self.corpus = corpus
        self.qrels = qrels
        # Fixed ordering of the queries so that integer indices are stable
        self.query_ids = list(queries.keys())
        
    def __len__(self):
        """Returns the number of queries."""
        return len(self.query_ids)
    
    def __getitem__(self, idx):
        """
        Returns the idx-th query with its judged documents.

        Returns:
            A dict with 'query_id', 'query_text', and the parallel lists 'doc_ids', 'doc_texts'
            (plain strings) and 'scores'. The lists are empty for a query without qrels.

        Raises:
            KeyError: If a judged document id is missing from the corpus.
        """
        query_id = self.query_ids[idx]
        
        # Queries without any judgment simply yield empty lists
        relevant_docs = self.qrels.get(query_id, {})
        doc_ids = list(relevant_docs)
        
        return {
            'query_id': query_id,
            'query_text': self.queries[query_id],
            'doc_ids': doc_ids,
            'doc_texts': [self._document_text(self.corpus[doc_id]) for doc_id in doc_ids],
            'scores': [relevant_docs[doc_id] for doc_id in doc_ids]
        }

    @staticmethod
    def _document_text(document):
        """Turns a corpus entry into the string that is fed to the encoder."""
        if isinstance(document, dict):
            # Corpus entries are dicts; join title and body so that 'doc_texts' really holds text
            # (collate_fn pads this list with '' and an encoder expects strings)
            title = document.get('title') or ''
            text = document.get('text') or ''
            return f"{title} {text}".strip()
        return document
    

def collate_fn(batch):
    """
    Merges a list of BEIRDataset items into one batch.

    Queries have different numbers of judged documents, so the per-query lists are right-padded
    to the longest one in the batch: document ids and texts with '', scores with 0.

    Args:
        batch: A list of dicts with the keys 'query_id', 'query_text', 'doc_ids', 'doc_texts' and 'scores'.

    Returns:
        A dict with 'query_ids' and 'query_texts' (lists), 'doc_ids' and 'doc_texts'
        (padded lists of lists) and 'scores' (a tensor built from the padded score lists).
    """
    def column(key):
        # Collect one field across all items of the batch
        return [item[key] for item in batch]

    query_ids = column('query_id')
    query_texts = column('query_text')
    doc_ids = column('doc_ids')
    doc_texts = column('doc_texts')
    scores = column('scores')

    # Target length: the largest number of documents any query in the batch has
    max_docs = max(len(ids) for ids in doc_ids)

    def pad(rows, filler):
        # Builds new lists; the items in `batch` are left untouched
        return [row + [filler] * (max_docs - len(row)) for row in rows]

    return {
        'query_ids': query_ids,
        'query_texts': query_texts,
        'doc_ids': pad(doc_ids, ''),
        'doc_texts': pad(doc_texts, ''),
        'scores': torch.tensor(pad(scores, 0))
    }

2024-08-22 23:26:10,937 - Loading Corpus...


100%|██████████| 3633/3633 [00:00<00:00, 17025.46it/s]


2024-08-22 23:26:11,197 - Loaded 3633 TRAIN Documents.
2024-08-22 23:26:11,197 - Doc Example: {'text': 'Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 partici

100%|██████████| 3633/3633 [00:00<00:00, 21010.59it/s]

2024-08-22 23:26:11,752 - Loaded 3633 DEV Documents.
2024-08-22 23:26:11,753 - Doc Example: {'text': 'Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 participa




2024-08-22 23:26:11,860 - Loaded 324 DEV Queries.
2024-08-22 23:26:11,861 - Query Example: Why Deep Fried Foods May Cause Cancer
2024-08-22 23:26:11,862 - Loading Corpus...


100%|██████████| 3633/3633 [00:00<00:00, 15800.83it/s]


2024-08-22 23:26:12,121 - Loaded 3633 TEST Documents.
2024-08-22 23:26:12,122 - Doc Example: {'text': 'Recent studies have suggested that statins, an established drug group in the prevention of cardiovascular mortality, could delay or prevent breast cancer recurrence but the effect on disease-specific mortality remains unclear. We evaluated risk of breast cancer death among statin users in a population-based cohort of breast cancer patients. The study cohort included all newly diagnosed breast cancer patients in Finland during 1995–2003 (31,236 cases), identified from the Finnish Cancer Registry. Information on statin use before and after the diagnosis was obtained from a national prescription database. We used the Cox proportional hazards regression method to estimate mortality among statin users with statin use as time-dependent variable. A total of 4,151 participants had used statins. During the median follow-up of 3.25 years after the diagnosis (range 0.08–9.0 years) 6,011 particip

In [5]:

# Training batches are reshuffled on every pass over the data
train_dataset = BEIRDataset(train_queries, train_corpus, train_qrels)
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

# Evaluation splits keep a fixed order so that runs are comparable
# (same convention as the unsupervised test loader)
dev_dataset = BEIRDataset(dev_queries, dev_corpus, dev_qrels)
dev_dataloader = DataLoader(dev_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

test_dataset = BEIRDataset(test_queries, test_corpus, test_qrels)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn)

#### Prepare Embedding Model and Matryoshka-Adaptor

In [6]:
from sentence_transformers import SentenceTransformer

# Embedding model whose outputs are fed to the adaptor
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# Matryoshka-Adaptor: a fresh instance for the supervised run.
# NOTE: the MatryoshkaAdaptor class cell must have been executed first; the recorded output of
# this cell shows "NameError: name 'MatryoshkaAdaptor' is not defined".
input_output_dim = model.get_sentence_embedding_dimension() # Embedding dimension for model (d in paper)
hidden_dim = input_output_dim # Let hidden layer dimension equal the embedding model dimension
mat_adaptor = MatryoshkaAdaptor(input_output_dim, hidden_dim)

2024-08-22 23:26:31,312 - NumExpr defaulting to 8 threads.
2024-08-22 23:26:33,030 - PyTorch version 2.3.0 available.
2024-08-22 23:26:33,039 - Polars version 1.5.0 available.
2024-08-22 23:26:37,634 - Use pytorch device_name: cpu
2024-08-22 23:26:37,634 - Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2


NameError: name 'MatryoshkaAdaptor' is not defined

#### Train Matryoshka-Adaptor

In [None]:
# Hyperparameters for the supervised run; the dict is passed to train() as its `kwargs` argument
hyperparams = {
    'epochs': 5,
    'lr': 1e-3,
    'k': 5,  # Top-k similarity loss
    'm': 128,  # Matryoshka embedding dimension
    'alpha': 1.0,  # Pairwise similarity loss scaling factor (alpha in paper)
    'beta': 1.0  # Regularization loss scaling factor (beta in paper)
}

# NOTE(review): `supervised_objective_fn_loss` is not defined in the part of the notebook shown
# here (the "Equation 5" placeholder is empty) - confirm it exists before running this cell.
# NOTE(review): train() passes each batch straight to model.encode(), while this dataloader's
# collate_fn yields dicts - presumably train() needs a supervised variant; verify.
train(model, mat_adaptor, train_dataloader, supervised_objective_fn_loss, hyperparams)

## BEIR Evaluation using NFCorpus

### Unmodified Model Performance for all-MiniLM-L6-v2

In [None]:
import mteb
from sentence_transformers import SentenceTransformer

# Baseline: the embedding model without any adaptor or dimensionality reduction
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Label used only to name the results folder
model_name = 'all-MiniLM-L6-v2_BASE'

# Define the BEIR tasks you want to evaluate on
tasks = mteb.get_tasks(tasks=["NFCorpus"])

# Evaluate the model on the benchmark; results are written below results/<model_name>
evaluation = mteb.MTEB(tasks=tasks)
results = evaluation.run(model, output_folder=f"results/{model_name}")


### all-MiniLM-L6-v2 + PCA

In [None]:
from sentence_transformers import SentenceTransformer
from sklearn.decomposition import PCA
import numpy as np

class PCASentenceTransformer(SentenceTransformer):
    """
    A SentenceTransformer model that applies PCA to reduce the dimensionality of the embeddings. 
    It serves as a wrapper to the inputted SentenceTransformer. 

    The PCA projection is either fitted explicitly with ``fit_pca`` or lazily on the embeddings
    of the first ``encode`` call; every later call reuses that projection.
    """
    def __init__(self, model_name_or_path, pca_components=128):
        """
        Args:
            model_name_or_path: Passed on to SentenceTransformer to load the underlying model.
            pca_components: Number of principal components to keep (output dimensionality).
        """
        super().__init__(model_name_or_path)
        self.pca_components = pca_components
        # Fitted sklearn PCA, or None until fit_pca()/encode() has fitted one
        self.pca = None

    def fit_pca(self, embeddings):
        """Fits PCA on the provided embeddings (array-like of shape (n_samples, n_features))."""
        self.pca = PCA(n_components=self.pca_components)
        self.pca.fit(embeddings)

    def encode(self, sentences, **kwargs):
        """
        Encodes the sentences and applies PCA to reduce dimensions.

        Args:
            sentences: A sentence or list of sentences, as accepted by SentenceTransformer.encode.
            **kwargs: Forwarded to SentenceTransformer.encode. ``convert_to_tensor`` is honoured
                for the return value; internally the embeddings are always requested as NumPy
                arrays because scikit-learn cannot work on (GPU) tensors.

        Returns:
            The reduced embeddings: a NumPy array by default, or a CPU float32 torch.Tensor when
            ``convert_to_tensor=True`` was passed. A single input sentence yields a 1-D result.

        Raises:
            ValueError: If PCA has not been fitted yet and this call provides fewer samples than
                ``pca_components``, which is too few to fit the projection.
        """
        import torch  # local import: the cell defining this class does not import torch itself

        # Remember what the caller asked for, but always fetch NumPy from the parent class
        return_tensor = kwargs.pop('convert_to_tensor', False)
        kwargs.pop('convert_to_numpy', None)
        embeddings = super().encode(sentences, convert_to_numpy=True, convert_to_tensor=False, **kwargs)
        embeddings = np.asarray(embeddings)

        # A single sentence comes back as a 1-D vector; PCA needs a 2-D (n_samples, n_features) array
        is_single = embeddings.ndim == 1
        if is_single:
            embeddings = embeddings.reshape(1, -1)

        # If PCA is not fitted, fit it using the embeddings of this (first) call
        if self.pca is None:
            if isinstance(self.pca_components, int) and embeddings.shape[0] < self.pca_components:
                raise ValueError(
                    f"Cannot fit PCA with {self.pca_components} components on "
                    f"{embeddings.shape[0]} embedding(s); call fit_pca() with a larger "
                    "sample first."
                )
            self.fit_pca(embeddings)

        # Transform the embeddings using the fitted PCA model
        reduced_embeddings = self.pca.transform(embeddings)
        if is_single:
            reduced_embeddings = reduced_embeddings[0]

        if return_tensor:
            return torch.as_tensor(reduced_embeddings, dtype=torch.float32)
        return reduced_embeddings

In [None]:
# Same embedding model, with its embeddings reduced to 128 dimensions by PCA
pca_model = PCASentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', pca_components=128)
# Label used only to name the results folder
model_name = 'all-MiniLM-L6-v2_PCA_128'

# Define the BEIR tasks you want to evaluate on
tasks = mteb.get_tasks(tasks=["NFCorpus"])

# Evaluate the model on the benchmark; results are written below results/<model_name>.
# NOTE: fit_pca() is not called here, so the PCA is fitted inside the first encode() call.
evaluation = mteb.MTEB(tasks=tasks)
results = evaluation.run(pca_model, output_folder=f"results/{model_name}")


### all-MiniLM-L6-v2 + Unsupervised Matryoshka Adaptor