# Retrofitting Embeddings

Baseline: Given a collection of documents D, extract their embeddings E (e.g., OpenAI embeddings). Then build an IR system and evaluate how well it retrieves the right documents.
Approach: Given a collection of documents D, extract their embeddings E (e.g., OpenAI embeddings). Then run dimensionality reduction on E (e.g., PCA with a fast implementation). Then build an IR system on the reduced embeddings and evaluate how well it retrieves the right documents.

In [23]:
import os
import pandas as pd
from sentence_transformers import SentenceTransformer
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from beir.datasets.data_loader import GenericDataLoader
from torchdr import PCA
import beir.util as util

class IdentityReduction:
    """No-op stand-in for a dimensionality reduction.

    Exposes the same ``fit`` / ``transform`` pair that the wrapper calls on
    its reduction object, but leaves the embeddings untouched.
    """

    def __init__(self, *args, **kwargs):
        """Accept and discard any constructor arguments."""

    def fit(self, x):
        """Do nothing; there is nothing to learn. Returns ``None``."""

    def transform(self, x):
        """Return ``x`` itself, unchanged."""
        return x

# Registry mapping a reduction-type key to the class that implements it:
# 'x' is the no-op baseline, 'pca' is the PCA class imported from torchdr.
reduction_classes = dict(x=IdentityReduction, pca=PCA)
import torch
import numpy as np
class STWrapper(SentenceTransformer):
    """SentenceTransformer whose encodings can be passed through a fitted reduction.

    The reduction object is built from ``reduction_classes[reduction_type]``
    (or ``IdentityReduction`` for ``'x'``). ``fit_reduction`` fits it on
    embeddings selected by ``fit_mode``; afterwards ``encode_queries`` and
    ``encode_corpus`` apply ``reduction.transform`` to whatever ``encode``
    returns.

    Args:
        model_name: Forwarded to ``SentenceTransformer.__init__``.
        reduction_type: Key into ``reduction_classes``; ``'x'`` means no reduction.
        reduction_kwargs: Keyword arguments for the reduction class
            (``None`` is treated as an empty dict).
        fit_mode: ``'joint_fit'`` (corpus + queries), ``'corpus_fit'`` or
            ``'query_fit'``. Any other value is only accepted together with
            ``reduction_type == 'x'``.
        *args, **kwargs: Forwarded to ``SentenceTransformer.__init__``.
    """

    # Fit modes that fit_reduction knows how to handle.
    _FIT_MODES = ('joint_fit', 'corpus_fit', 'query_fit')

    def __init__(self, model_name, reduction_type='x', reduction_kwargs=None, fit_mode='query_fit', *args, **kwargs):
        super().__init__(model_name, *args, **kwargs)
        # A ``{}`` default would be one dict shared by every instance.
        if reduction_kwargs is None:
            reduction_kwargs = {}
        self.reduction_type = reduction_type
        self.reduction_kwargs = reduction_kwargs
        self.fit_mode = fit_mode
        if reduction_type == 'x':
            self.reduction = IdentityReduction()
        else:
            self.reduction = reduction_classes[reduction_type](**reduction_kwargs)
        self.fitted = False
        # Filled in by fit_reduction, depending on fit_mode.
        self.corpus_embeddings = None
        self.query_embeddings = None

    def run_fit(self, embeddings, max_embeddings=10000):
        """Fit the reduction on ``embeddings``, subsampling if there are too many.

        Args:
            embeddings: Indexable collection of embeddings supporting ``len``
                and fancy indexing with an integer array.
            max_embeddings: Upper bound on the number of rows used for the
                fit. Larger inputs are sampled uniformly without replacement.
        """
        if len(embeddings) > max_embeddings:
            random_sample_idx = np.random.choice(len(embeddings), max_embeddings, replace=False)
            embeddings = embeddings[random_sample_idx]
        self.reduction.fit(embeddings)

    def fit_reduction(self, corpus, queries, *args, **kwargs):
        """Encode the data selected by ``fit_mode`` and fit the reduction on it.

        Extra positional/keyword arguments are forwarded to ``encode``. The
        embeddings that were computed are kept on ``corpus_embeddings`` /
        ``query_embeddings``.

        Raises:
            ValueError: If ``fit_mode`` is not a known mode and the reduction
                is not the identity (it would otherwise be marked as fitted
                without ever having been fitted).
        """
        if self.fit_mode == 'joint_fit':
            self.corpus_embeddings = self.encode(corpus, *args, **kwargs)
            self.query_embeddings = self.encode(queries, *args, **kwargs)
            # encode may return torch tensors or numpy arrays depending on kwargs.
            if isinstance(self.corpus_embeddings, torch.Tensor):
                combined = torch.cat([self.corpus_embeddings, self.query_embeddings], dim=0)
            else:
                combined = np.concatenate([self.corpus_embeddings, self.query_embeddings], axis=0)
            self.run_fit(combined)
        elif self.fit_mode == 'corpus_fit':
            self.corpus_embeddings = self.encode(corpus, *args, **kwargs)
            self.run_fit(self.corpus_embeddings)
        elif self.fit_mode == 'query_fit':
            self.query_embeddings = self.encode(queries, *args, **kwargs)
            self.run_fit(self.query_embeddings)
        elif self.reduction_type != 'x':
            raise ValueError(
                f"Unknown fit_mode {self.fit_mode!r}; expected one of {self._FIT_MODES}."
            )
        # For the identity reduction an unknown fit_mode is harmless: nothing to fit.
        self.fitted = True

    def _encode_and_reduce(self, inputs, *args, **kwargs):
        """Encode ``inputs`` and, once fitted, apply the reduction to the result."""
        embeddings = self.encode(inputs, *args, **kwargs)
        if self.fitted:
            embeddings = self.reduction.transform(embeddings)
        return embeddings

    def encode_queries(self, queries, *args, **kwargs):
        """Return (reduced, if fitted) embeddings for ``queries``."""
        return self._encode_and_reduce(queries, *args, **kwargs)

    def encode_corpus(self, corpus, *args, **kwargs):
        """Return (reduced, if fitted) embeddings for ``corpus``."""
        return self._encode_and_reduce(corpus, *args, **kwargs)
    


# Experiment driver: for every dataset / model / reduction / fit-mode
# combination, fit the reduction, run BEIR retrieval and collect the metrics
# as one single-column DataFrame per configuration in ``df_list``.
import gc  # used for the explicit per-run cleanup at the end of each iteration

# Models to evaluate; uncomment entries to include them.
model_name_list = [
    # 'all-mpnet-base-v2',
    # 'all-MiniLM-L12-v1',
    # "sentence-t5-xl"
    # "hkunlp/instructor-large"
    "multi-qa-mpnet-base-cos-v1",
]
# NOTE: not referenced by the loop below, which takes its kwargs from reduction_list.
reduction_kwargs_choices = {
    'x': {},
    'pca': {'n_components': 128},
}
out_dir = "./beir"
os.makedirs(out_dir, exist_ok=True)

# dataset_list = ['scifact', 'hotpotqa', 'fiqa', 'fever']
# dataset_list = ['fiqa', ]
dataset_list = ['scidocs', ]


# Each entry is (reduction_type, reduction_kwargs); 'x' is the unreduced baseline.
# reduction_list = [('x', {}), ('pca', {'n_components': 700}), ('pca', {'n_components': 256})]
reduction_list = [('x', {}), ('pca', {'n_components': 700}), ('pca', {'n_components': 300}), ]

# fit_modes = ['corpus_fit', 'joint_fit']
fit_modes = ['query_fit', ]


df_list = []
for dataset in dataset_list:
    url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
    data_path = util.download_and_unzip(url, out_dir)
    corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

    # These depend only on the dataset, so build them once per dataset rather
    # than once per model/reduction/fit-mode combination.
    corpus_ids = sorted(corpus, key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")), reverse=True)
    # TODO: currently using the default processing from sentence transformer to handle title + text,
    #  may change to other ways
    corpus_to_fit = [corpus[cid] for cid in corpus_ids]
    queries_to_fit = [queries[qid] for qid in queries]

    for model_name in model_name_list:
        for reduction_type, reduction_kwargs in reduction_list:
            # The baseline has nothing to fit, so it runs once with an empty fit mode.
            cur_fit_modes = [''] if reduction_type == 'x' else fit_modes
            for fit_mode in cur_fit_modes:
                results_dict = {}
                components = reduction_kwargs.get('n_components', '')
                base_model = STWrapper(model_name, reduction_type, reduction_kwargs, fit_mode)

                base_model.fit_reduction(corpus_to_fit, queries_to_fit, convert_to_tensor=True)
                model = DRES(base_model)
                retriever = EvaluateRetrieval(model, score_function="cos_sim")

                results = retriever.retrieve(corpus, queries)
                ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
                # NOTE: the key does not contain the dataset name, so running several
                # datasets in one go yields duplicate column names after concatenation.
                key = f"{model_name}+{reduction_type}+{components}+{fit_mode}"
                # Flatten the four metric dicts into one column, as percentages with one decimal.
                results_dict[key] = {k: round(v * 100, 1) for result in [ndcg, _map, recall, precision] for k, v in result.items()}

                df = pd.DataFrame(results_dict)
                df_list.append(df)

                # Drop every reference to this run's model before cleaning up:
                # retriever was constructed with model, which wraps base_model.
                del retriever
                del model
                del base_model
                # Collect first so freed tensors can actually be returned by empty_cache.
                gc.collect()
                torch.cuda.empty_cache()


  0%|          | 0/25657 [00:00<?, ?it/s]



Batches:   0%|          | 0/8 [00:00<?, ?it/s]

Batches:   0%|          | 0/201 [00:00<?, ?it/s]



Batches:   0%|          | 0/8 [00:00<?, ?it/s]

Batches:   0%|          | 0/201 [00:00<?, ?it/s]



Batches:   0%|          | 0/8 [00:00<?, ?it/s]

Batches:   0%|          | 0/201 [00:00<?, ?it/s]

` subs sample some of them
` niche documents

not simple queries and documents
~ not frequent document
~ WHAT KIND OF DATA
~ bio medical
~ long context documents ( think sse)


SAE feature 

- structural match from SAE

In [24]:
# scidocs + multi-qa-mpnet-base-cos-v1 (model name taken from the column headers of the output below)
# Concatenate the per-configuration result frames column-wise and display the table.
all_df = pd.concat(df_list, axis=1)

all_df

Unnamed: 0,multi-qa-mpnet-base-cos-v1+x++,multi-qa-mpnet-base-cos-v1+pca+700+query_fit,multi-qa-mpnet-base-cos-v1+pca+300+query_fit
MAP@1,3.7,3.8,3.8
MAP@10,9.0,9.4,9.4
MAP@100,10.5,11.0,10.9
MAP@1000,10.7,11.2,11.2
MAP@3,6.5,6.9,6.8
MAP@5,7.8,8.2,8.1
NDCG@1,18.4,18.7,18.8
NDCG@10,15.6,16.3,16.1
NDCG@100,21.9,22.8,22.7
NDCG@1000,26.8,27.7,27.6


In [None]:
# Load the "minimario/math-openwebmath-retrievals" dataset via `datasets.load_dataset`.
# NOTE(review): `ds` is not used by any other cell visible in this notebook.
from datasets import load_dataset

ds = load_dataset("minimario/math-openwebmath-retrievals")

In [21]:
# scidocs + hkunlp/instructor-large (model name taken from the column headers of the output below)
# Concatenate the per-configuration result frames column-wise and display the table.
all_df = pd.concat(df_list, axis=1)

all_df

Unnamed: 0,hkunlp/instructor-large+x++,hkunlp/instructor-large+pca+700+query_fit,hkunlp/instructor-large+pca+300+query_fit
MAP@1,3.8,3.7,3.7
MAP@10,9.9,10.0,9.9
MAP@100,12.0,12.0,11.9
MAP@1000,12.3,12.3,12.2
MAP@3,7.0,7.0,6.9
MAP@5,8.4,8.5,8.3
NDCG@1,18.8,18.1,18.3
NDCG@10,17.3,17.3,17.1
NDCG@100,25.3,25.2,25.1
NDCG@1000,30.9,30.4,30.2


In [19]:
# scidocs + multi-qa-mpnet-base-cos-v1 (model name taken from the column headers of the output below)
# Concatenate the per-configuration result frames column-wise and display the table.
all_df = pd.concat(df_list, axis=1)

all_df

Unnamed: 0,multi-qa-mpnet-base-cos-v1+x++,multi-qa-mpnet-base-cos-v1+pca+300+query_fit
MAP@1,3.7,3.8
MAP@10,9.0,9.4
MAP@100,10.5,10.9
MAP@1000,10.7,11.2
MAP@3,6.5,6.8
MAP@5,7.8,8.1
NDCG@1,18.4,18.8
NDCG@10,15.6,16.1
NDCG@100,21.9,22.7
NDCG@1000,26.8,27.6


In [15]:
# nfcorpus + multi-qa-mpnet-base-cos-v1:
# Concatenate the per-configuration result frames column-wise and display the table.
all_df = pd.concat(df_list, axis=1)

all_df

Unnamed: 0,multi-qa-mpnet-base-cos-v1+x++,multi-qa-mpnet-base-cos-v1+pca+300+query_fit,multi-qa-mpnet-base-cos-v1+pca+256+query_fit
MAP@1,5.9,5.3,5.3
MAP@10,11.5,11.0,11.0
MAP@100,14.4,13.9,13.9
MAP@1000,15.7,15.2,15.2
MAP@3,8.8,8.2,8.2
MAP@5,10.0,9.4,9.4
NDCG@1,43.8,41.0,40.7
NDCG@10,31.7,30.9,30.9
NDCG@100,29.2,28.7,28.6
NDCG@1000,38.1,37.4,37.4


In [2]:
# scidocs + sentence-t5-xl (the result columns below are labelled sentence-t5-xl, not a small T5)
# NOTE(review): the baseline and pca+720 numbers below are identical to the table of the
# following cell, which is labelled scifact -- verify which dataset this run actually used.
# Concatenate the per-configuration result frames column-wise and display the table.
all_df = pd.concat(df_list, axis=1)

all_df

Unnamed: 0,sentence-t5-xl+x++,sentence-t5-xl+pca+720+query_fit,sentence-t5-xl+pca+512+query_fit
MAP@1,4.4,3.9,3.9
MAP@10,11.3,10.3,10.3
MAP@100,14.7,14.0,14.0
MAP@1000,16.3,15.7,15.7
MAP@3,7.8,6.9,6.9
MAP@5,9.3,8.4,8.4
NDCG@1,39.2,38.7,38.7
NDCG@10,32.3,31.2,31.2
NDCG@100,30.5,30.1,30.1
NDCG@1000,39.8,39.2,39.2


In [10]:
# scifact + sentence-t5-xl (the result columns below are labelled sentence-t5-xl, not a small T5)
# NOTE(review): these numbers are identical to the preceding table labelled scidocs and differ
# from the other scifact table in this notebook -- verify which dataset this run actually used.
# Concatenate the per-configuration result frames column-wise and display the table.
all_df = pd.concat(df_list, axis=1)

all_df

Unnamed: 0,sentence-t5-xl+x++,sentence-t5-xl+pca+720+query_fit
MAP@1,4.4,3.9
MAP@10,11.3,10.3
MAP@100,14.7,14.0
MAP@1000,16.3,15.7
MAP@3,7.8,6.9
MAP@5,9.3,8.4
NDCG@1,39.2,38.7
NDCG@10,32.3,31.2
NDCG@100,30.5,30.1
NDCG@1000,39.8,39.2


In [2]:
# scifact + sentence-t5-xl (the result columns below are labelled sentence-t5-xl, not a small T5)
# Concatenate the per-configuration result frames column-wise and display the table.
all_df = pd.concat(df_list, axis=1)

all_df

Unnamed: 0,sentence-t5-xl+x++,sentence-t5-xl+pca+700+query_fit
MAP@1,32.8,31.5
MAP@10,43.3,42.4
MAP@100,44.2,43.3
MAP@1000,44.3,43.3
MAP@3,40.3,39.0
MAP@5,42.1,41.3
NDCG@1,35.0,33.3
NDCG@10,48.4,48.1
NDCG@100,53.1,52.5
NDCG@1000,54.6,53.9


In [14]:
# nfcorpus
# Concatenate the per-configuration result frames column-wise and display the table.
all_df = pd.concat(df_list, axis=1)
# Disabled: CSV export and per-model "% change" columns. The column names it looks up
# ("{model}+pca" / "{model}+x") follow an older key format; the current keys also contain
# the component count and fit mode, so this would need updating before being re-enabled.
# all_df.to_csv(f"./RE-results.csv")
# for model in model_name_list:
#     col_pca = f"{model}+pca"
#     col_x = f"{model}+x"
#     all_df[f"{model} % change"] = ((all_df[col_pca] - all_df[col_x]) / all_df[col_x] * 100).round(1)
# all_df[sorted(all_df.columns, reverse=True)]
all_df

Unnamed: 0,all-mpnet-base-v2+x++,all-mpnet-base-v2+pca+700+query_fit,all-mpnet-base-v2+pca+700+corpus_fit,all-mpnet-base-v2+pca+700+joint_fit
MAP@1,5.0,4.2,3.3,3.3
MAP@10,11.7,10.4,8.6,8.8
MAP@100,15.3,14.2,11.8,12.1
MAP@1000,16.8,15.8,13.3,13.5
MAP@3,8.3,7.1,5.6,5.8
MAP@5,9.7,8.5,7.0,7.1
NDCG@1,39.5,37.6,33.7,34.1
NDCG@10,32.4,30.6,27.0,27.4
NDCG@100,31.0,29.8,26.1,26.5
NDCG@1000,40.0,38.6,35.0,35.4


In [None]:
# scidocs
# Concatenate the per-configuration result frames column-wise and display the table.
all_df = pd.concat(df_list, axis=1)
# Disabled: CSV export and per-model "% change" columns. The column names it looks up
# ("{model}+pca" / "{model}+x") follow an older key format; the current keys also contain
# the component count and fit mode, so this would need updating before being re-enabled.
# all_df.to_csv(f"./RE-results.csv")
# for model in model_name_list:
#     col_pca = f"{model}+pca"
#     col_x = f"{model}+x"
#     all_df[f"{model} % change"] = ((all_df[col_pca] - all_df[col_x]) / all_df[col_x] * 100).round(1)
# all_df[sorted(all_df.columns, reverse=True)]
all_df

Unnamed: 0,all-mpnet-base-v2+x++,all-mpnet-base-v2+pca+700+query_fit
MAP@1,4.9,4.8
MAP@10,13.3,13.1
MAP@100,15.9,15.6
MAP@1000,16.2,15.9
MAP@3,9.2,8.9
MAP@5,11.2,11.0
NDCG@1,23.9,23.3
NDCG@10,22.3,21.9
NDCG@100,31.8,31.2
NDCG@1000,37.2,36.5


In [6]:
# fiqa, reduction fitted on queries only (query_fit)
# Concatenate the per-configuration result frames column-wise and display the table.
all_df = pd.concat(df_list, axis=1)
# Disabled: CSV export and per-model "% change" columns. The column names it looks up
# ("{model}+pca" / "{model}+x") follow an older key format; the current keys also contain
# the component count and fit mode, so this would need updating before being re-enabled.
# all_df.to_csv(f"./RE-results.csv")
# for model in model_name_list:
#     col_pca = f"{model}+pca"
#     col_x = f"{model}+x"
#     all_df[f"{model} % change"] = ((all_df[col_pca] - all_df[col_x]) / all_df[col_x] * 100).round(1)
# all_df[sorted(all_df.columns, reverse=True)]
all_df

Unnamed: 0,all-mpnet-base-v2+x++,all-mpnet-base-v2+pca+700+query_fit
MAP@1,24.4,24.4
MAP@10,41.1,41.1
MAP@100,43.2,43.3
MAP@1000,43.3,43.4
MAP@3,35.7,36.1
MAP@5,38.7,39.1
NDCG@1,49.1,48.3
NDCG@10,50.0,49.7
NDCG@100,56.6,56.5
NDCG@1000,58.9,58.9


In [6]:
# NOTE(review): the dataset for this run is not recorded in the cell.
# Concatenate the per-configuration result frames column-wise and display the table.
all_df = pd.concat(df_list, axis=1)
# Disabled: CSV export and per-model "% change" columns. The column names it looks up
# ("{model}+pca" / "{model}+x") follow an older key format; the current keys also contain
# the component count and fit mode, so this would need updating before being re-enabled.
# all_df.to_csv(f"./RE-results.csv")
# for model in model_name_list:
#     col_pca = f"{model}+pca"
#     col_x = f"{model}+x"
#     all_df[f"{model} % change"] = ((all_df[col_pca] - all_df[col_x]) / all_df[col_x] * 100).round(1)
# all_df[sorted(all_df.columns, reverse=True)]
all_df

Unnamed: 0,all-mpnet-base-v2+x++,all-mpnet-base-v2+pca+700+corpus_fit,all-mpnet-base-v2+pca+700+joint_fit,all-mpnet-base-v2+pca+128+corpus_fit,all-mpnet-base-v2+pca+128+joint_fit
MAP@1,47.8,48.9,48.9,42.9,45.4
MAP@10,58.4,58.1,58.1,53.2,54.5
MAP@100,59.4,59.1,59.1,54.2,55.6
MAP@1000,59.5,59.1,59.1,54.3,55.6
MAP@3,55.3,55.1,55.1,50.4,51.8
MAP@5,57.4,56.9,56.9,52.1,53.5
NDCG@1,49.7,51.0,51.0,45.0,47.7
NDCG@10,63.3,62.9,62.9,58.2,59.2
NDCG@100,67.4,67.0,67.0,62.6,63.8
NDCG@1000,68.0,67.7,67.7,63.7,64.8


In [10]:
# NOTE(review): the dataset for this run is not recorded in the cell.
# Concatenate the per-configuration result frames column-wise and display the table.
all_df = pd.concat(df_list, axis=1)
# Disabled: CSV export and per-model "% change" columns. The column names it looks up
# ("{model}+pca" / "{model}+x") do not include the component count that the result
# columns below carry, so this would need updating before being re-enabled.
# all_df.to_csv(f"./RE-results.csv")
# for model in model_name_list:
#     col_pca = f"{model}+pca"
#     col_x = f"{model}+x"
#     all_df[f"{model} % change"] = ((all_df[col_pca] - all_df[col_x]) / all_df[col_x] * 100).round(1)
# all_df[sorted(all_df.columns, reverse=True)]
all_df

Unnamed: 0,all-mpnet-base-v2+x+,all-mpnet-base-v2+pca+360,all-mpnet-base-v2+pca+128
MAP@1,47.8,48.5,45.9
MAP@10,58.4,58.2,56.1
MAP@100,59.4,59.2,57.1
MAP@1000,59.5,59.2,57.1
MAP@3,55.3,55.5,53.2
MAP@5,57.4,56.9,54.8
NDCG@1,49.7,50.7,48.0
NDCG@10,63.3,63.1,61.1
NDCG@100,67.4,67.0,65.5
NDCG@1000,68.0,67.7,66.2


In [12]:
# NOTE(review): the dataset for this run is not recorded in the cell.
# Concatenate the per-configuration result frames column-wise.
all_df = pd.concat(df_list, axis=1)
# Disabled: CSV export, display, and per-model "% change" columns (relative difference of
# the "{model}+pca" column to the "{model}+x" column, in percent, rounded to one decimal).
# The table below appears to come from a run where this code was enabled.
# all_df.to_csv(f"./RE-results.csv")
# all_df
# for model in model_name_list:
#     col_pca = f"{model}+pca"
#     col_x = f"{model}+x"
#     all_df[f"{model} % change"] = ((all_df[col_pca] - all_df[col_x]) / all_df[col_x] * 100).round(1)
# all_df[sorted(all_df.columns, reverse=True)]

Unnamed: 0,all-mpnet-base-v2+x,all-mpnet-base-v2+pca,all-mpnet-base-v2 % change,all-MiniLM-L12-v1+x,all-MiniLM-L12-v1+pca,all-MiniLM-L12-v1 % change
MAP@1,0.47789,0.45872,-4.0,0.48139,0.45206,-6.1
MAP@10,0.58434,0.56075,-4.0,0.57537,0.54391,-5.5
MAP@100,0.59424,0.57104,-3.9,0.58381,0.55368,-5.2
MAP@1000,0.59454,0.57132,-3.9,0.58409,0.55398,-5.2
MAP@3,0.55327,0.53157,-3.9,0.54461,0.51169,-6.0
MAP@5,0.57367,0.54806,-4.5,0.56156,0.52871,-5.8
NDCG@1,0.49667,0.48,-3.4,0.49667,0.47,-5.4
NDCG@10,0.63309,0.61067,-3.5,0.6217,0.59294,-4.6
NDCG@100,0.67379,0.65479,-2.8,0.6599,0.63539,-3.7
NDCG@1000,0.68032,0.66179,-2.7,0.66767,0.64449,-3.5
