# Retrofitting Embeddings

Baseline: Given a collection of documents D, extract their embeddings E (e.g., OpenAI embeddings). Then build an IR system and evaluate it on retrieving the right documents.
Approach: Given a collection of documents D, extract their embeddings E (e.g., OpenAI embeddings). Then run dimensionality reduction on E (e.g., PCA with fast implementations). Then build an IR system and evaluate it on retrieving the right documents.

In [1]:
from beir import util, LoggingHandler
from beir.retrieval import models
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES

import os

# import logging

# #### Just some code to print debug information to stdout
# logging.basicConfig(format='%(asctime)s - %(message)s',
#                     datefmt='%Y-%m-%d %H:%M:%S',
#                     level=logging.INFO,
#                     handlers=[LoggingHandler()])


from torchdr import PCA, TSNE, KernelPCA
from sentence_transformers import SentenceTransformer
import pandas as pd


  from tqdm.autonotebook import tqdm


In [2]:
# model_name_list =[
#     # 'all-mpnet-base-v2', #  0.5481, 0.2312, 0.2559 (1 mins) (2312, 2559)
#     # # 'all-mpnet-base-v2', #[full 10] 2414 -> 2684kwo, 2440kw, 2298kp, 2740kpo, 2352kso, 2399ks
#     # 'sentence-t5-xl', #[full] 0.6754, 0.2543, 0.2990

#     # 'all-MiniLM-L12-v1',
#     'all-mpnet-base-v2',
# ]

# reduction_classes = {
#     'pca': PCA,
#     # 'tsne': TSNE
# }

# class idenity_reduction:

#     def __init__(self, *args, **kwargs):
#         pass

#     def fit(self, x):
#         pass

#     def transform(self, x):
#         return x

# class ST_wrapper(SentenceTransformer):
#     def __init__(self, model_name, reduction_type = 'x', reduction_kwargs={}, *args, **kwargs):
#         super(ST_wrapper, self).__init__(model_name, *args, **kwargs)
#         if reduction_type == 'x':
#             self.reduction = idenity_reduction()
#         else:
#             self.reduction = reduction_classes[reduction_type](**reduction_kwargs)
#         print(reduction_kwargs)

#     def encode_queries(self, queries, *args, **kwargs):
#         embeddings = self.encode(queries, *args, **kwargs)
#         self.reduction.fit(embeddings)
#         return self.reduction.transform(embeddings)

#     def encode_corpus(self, corpus, *args, **kwargs):
#         embeddings = self.encode(corpus, *args, **kwargs)
#         print(embeddings.shape)
#         return self.reduction.transform(embeddings)


# reduction_kwargs_choices = {
#     'x': {},
#     'pca': {'n_components': 128},
#     # 'tsne': {'perplexity': 30}s
# }


# out_dir = "./beir"
# os.mkdir(out_dir) if not os.path.exists(out_dir) else None
# #### /print debug information to stdout

# #### Download scifact.zip dataset and unzip the dataset
# # dataset_list = ['scifact', 'hotpotqa', 'fiqa', 'fever']
# dataset_list = ['scifact',]


# reduction_list = [
#     ('x', {}),
# ] + [('pca', {'n_components': n}) for n in [360, 128]]

# df_list = []
# for dataset in dataset_list:

#     # dataset = "scidocs"
#     url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)

#     data_path = util.download_and_unzip(url, out_dir)

#     #### Provide the data_path where scifact has been downloaded and unzipped
#     corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")


#     for model_name in model_name_list:
#         for reduction_type, reduction_kwargs in reduction_list:
#             results_dict = {}
#             # reduction_type = reduction_type 
#             components = reduction_kwargs.get('n_components', '')
#             model = ST_wrapper(model_name, reduction_type, reduction_kwargs)
#             model = DRES(model, batch_size=128)

#             retriever = EvaluateRetrieval(model, score_function="dot") # or "cos_sim" for cosine similarity
#             results = retriever.retrieve(corpus, queries)
#             ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
#             results_dict[f"{model_name}+{reduction_type}+{components}"] = {}
#             for result in [ndcg, _map, recall, precision]:
#                 for k, v in result.items():
#                     results_dict[f"{model_name}+{reduction_type}+{components}"][k] = round(v*100, 1)

#             df = pd.DataFrame(results_dict)
#             df_list.append(df)

In [None]:
import os
import pandas as pd
from sentence_transformers import SentenceTransformer
from beir.retrieval.evaluation import EvaluateRetrieval
from beir.retrieval.search.dense import DenseRetrievalExactSearch as DRES
from beir.datasets.data_loader import GenericDataLoader
from torchdr import PCA
import beir.util as util

class IdentityReduction:
    """No-op stand-in for a dimensionality reducer.

    Provides ``fit`` and ``transform`` so it can be used wherever an object
    with those two methods is expected, while leaving the data untouched.
    """

    def __init__(self, *args, **kwargs):
        """Accept and discard any positional or keyword arguments."""

    def fit(self, x):
        """Do nothing with ``x``; returns ``None``."""
        return None

    def transform(self, x):
        """Return ``x`` itself, unchanged."""
        return x

# Lookup table: reduction-type key -> class that implements that reduction.
reduction_classes = {'x': IdentityReduction, 'pca': PCA}
import torch
import numpy as np
class STWrapper(SentenceTransformer):
    """SentenceTransformer that can post-process its embeddings with a reduction.

    The reduction object is built from ``reduction_classes[reduction_type]``
    (or ``IdentityReduction`` for ``'x'``). It must be fitted explicitly via
    :meth:`fit_reduction`; until then ``encode_queries`` / ``encode_corpus``
    return the raw output of ``encode``.

    Args:
        model_name: Forwarded to ``SentenceTransformer.__init__``.
        reduction_type: Key into ``reduction_classes``; ``'x'`` means no reduction.
        reduction_kwargs: Keyword arguments for the reduction class constructor.
            ``None`` (the default) means no keyword arguments.
        fit_mode: Which embeddings the reduction is fitted on:
            ``'joint_fit'`` (corpus + queries), ``'corpus_fit'`` or ``'query_fit'``.
        *args, **kwargs: Forwarded to ``SentenceTransformer.__init__``.
    """

    def __init__(self, model_name, reduction_type='x', reduction_kwargs=None, fit_mode='query_fit', *args, **kwargs):
        super().__init__(model_name, *args, **kwargs)
        # Use a fresh dict per instance instead of a `{}` default, which would
        # be one object shared (and stored) by every instance.
        if reduction_kwargs is None:
            reduction_kwargs = {}
        self.reduction_type = reduction_type
        self.reduction_kwargs = reduction_kwargs
        self.fit_mode = fit_mode
        if reduction_type == 'x':
            self.reduction = IdentityReduction()
        else:
            self.reduction = reduction_classes[reduction_type](**reduction_kwargs)
        self.fitted = False
        # Embeddings computed during fit_reduction are kept on the instance.
        self.corpus_embeddings = None
        self.query_embeddings = None

    def fit_reduction(self, corpus, queries, *args, **kwargs):
        """Encode the data selected by ``fit_mode`` and fit the reduction on it.

        Args:
            corpus: Corpus items, passed to ``encode`` for the joint and corpus modes.
            queries: Query items, passed to ``encode`` for the joint and query modes.
            *args, **kwargs: Forwarded to ``encode``.

        Raises:
            ValueError: If ``self.fit_mode`` is not one of the supported modes.
        """
        if self.fit_mode == 'joint_fit':
            self.corpus_embeddings = self.encode(corpus, *args, **kwargs)
            self.query_embeddings = self.encode(queries, *args, **kwargs)
            # Stack corpus and query embeddings row-wise, using the backend
            # that matches what ``encode`` returned.
            if isinstance(self.corpus_embeddings, torch.Tensor):
                combined = torch.cat([self.corpus_embeddings, self.query_embeddings], dim=0)
            else:
                combined = np.concatenate([self.corpus_embeddings, self.query_embeddings], axis=0)
            self.reduction.fit(combined)
        elif self.fit_mode == 'corpus_fit':
            self.corpus_embeddings = self.encode(corpus, *args, **kwargs)
            self.reduction.fit(self.corpus_embeddings)
        elif self.fit_mode == 'query_fit':
            self.query_embeddings = self.encode(queries, *args, **kwargs)
            self.reduction.fit(self.query_embeddings)
        else:
            # Previously an unknown mode fell through and marked the wrapper
            # as fitted without ever fitting the reduction.
            raise ValueError(
                f"Unknown fit_mode {self.fit_mode!r}; "
                "expected 'joint_fit', 'corpus_fit' or 'query_fit'"
            )
        self.fitted = True

    def encode_queries(self, queries, *args, **kwargs):
        """Encode ``queries``; apply the reduction if it has been fitted."""
        embeddings = self.encode(queries, *args, **kwargs)
        if self.fitted:
            embeddings = self.reduction.transform(embeddings)
        return embeddings

    def encode_corpus(self, corpus, *args, **kwargs):
        """Encode ``corpus``; apply the reduction if it has been fitted."""
        embeddings = self.encode(corpus, *args, **kwargs)
        if self.fitted:
            embeddings = self.reduction.transform(embeddings)
        return embeddings
    


# --- Experiment configuration -------------------------------------------------
model_name_list = [
    'all-mpnet-base-v2',
    # 'all-MiniLM-L12-v1',
]
reduction_kwargs_choices = {
    'x': {},
    'pca': {'n_components': 128},
}
out_dir = "./beir"
os.makedirs(out_dir, exist_ok=True)

# dataset_list = ['scifact', 'hotpotqa', 'fiqa', 'fever']
dataset_list = ['scifact', ]

# Each entry is (reduction_type, reduction_kwargs); 'x' is the no-reduction baseline.
reduction_list = [('x', {}), ('pca', {'n_components': 360}), ('pca', {'n_components': 128})]
fit_modes = ['corpus_fit', 'joint_fit']

# --- Evaluation loop ----------------------------------------------------------
# One single-column DataFrame per (model, reduction, fit_mode) is appended here.
df_list = []
for dataset in dataset_list:
    url = f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
    data_path = util.download_and_unzip(url, out_dir)
    corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

    # The data used to fit the reduction depends only on the dataset, so build
    # it once here rather than re-sorting the corpus for every configuration.
    # Corpus ids are ordered by combined title + text length, longest first.
    corpus_ids = sorted(
        corpus,
        key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")),
        reverse=True,
    )
    # TODO: currently using the default processing from sentence transformer to handle title + text,
    #  may change to other ways
    corpus_to_fit = [corpus[cid] for cid in corpus_ids]
    queries_to_fit = [queries[qid] for qid in queries]

    for model_name in model_name_list:
        for reduction_type, reduction_kwargs in reduction_list:
            for fit_mode in fit_modes:
                results_dict = {}
                # Empty string when the reduction has no 'n_components' (the 'x' baseline).
                components = reduction_kwargs.get('n_components', '')
                base_model = STWrapper(model_name, reduction_type, reduction_kwargs, fit_mode)

                base_model.fit_reduction(corpus_to_fit, queries_to_fit, convert_to_tensor=True)
                model = DRES(base_model, batch_size=32)
                retriever = EvaluateRetrieval(model, score_function="dot")

                results = retriever.retrieve(corpus, queries)
                ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
                key = f"{model_name}+{reduction_type}+{components}+{fit_mode}"
                # Flatten all four metric dicts into one column, scaled to percent.
                results_dict[key] = {
                    k: round(v * 100, 1)
                    for result in [ndcg, _map, recall, precision]
                    for k, v in result.items()
                }

                df = pd.DataFrame(results_dict)
                df_list.append(df)


  0%|          | 0/5183 [00:00<?, ?it/s]



Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Batches:   0%|          | 0/162 [00:00<?, ?it/s]



Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Batches:   0%|          | 0/162 [00:00<?, ?it/s]



Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Batches:   0%|          | 0/162 [00:00<?, ?it/s]



Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Batches:   0%|          | 0/162 [00:00<?, ?it/s]



Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Batches:   0%|          | 0/162 [00:00<?, ?it/s]



- subsample some of them
- niche documents

not simple queries and documents
- not frequent documents
- WHAT KIND OF DATA
- biomedical
- long-context documents (think sse)


SAE feature 

- structural match from SAE

In [None]:
# Concatenate the frames collected in df_list column-wise (axis=1) and display the result.
all_df = pd.concat(df_list, axis=1)
# NOTE(review): the commented-out "% change" code below looks up columns named
# "<model>+pca" / "<model>+x"; the table shown under this cell has columns like
# "<model>+pca+360+corpus_fit", so it would need updating before being re-enabled.
# all_df.to_csv(f"./RE-results.csv")
# for model in model_name_list:
#     col_pca = f"{model}+pca"
#     col_x = f"{model}+x"
#     all_df[f"{model} % change"] = ((all_df[col_pca] - all_df[col_x]) / all_df[col_x] * 100).round(1)
# all_df[sorted(all_df.columns, reverse=True)]
all_df

Unnamed: 0,all-MiniLM-L12-v1+x++corpus_fit,all-MiniLM-L12-v1+x++joint_fit,all-MiniLM-L12-v1+pca+360+corpus_fit,all-MiniLM-L12-v1+pca+360+joint_fit,all-MiniLM-L12-v1+pca+128+corpus_fit,all-MiniLM-L12-v1+pca+128+joint_fit
MAP@1,48.1,48.1,43.7,43.7,37.0,37.9
MAP@10,57.5,57.5,53.8,53.7,48.0,48.9
MAP@100,58.4,58.4,54.6,54.6,49.1,50.0
MAP@1000,58.4,58.4,54.7,54.7,49.2,50.1
MAP@3,54.5,54.5,50.4,50.4,45.0,46.4
MAP@5,56.2,56.2,52.0,52.1,46.7,47.8
NDCG@1,49.7,49.7,45.0,45.0,38.3,39.0
NDCG@10,62.2,62.2,59.0,58.9,53.2,53.9
NDCG@100,66.0,66.0,62.9,63.0,58.0,58.8
NDCG@1000,66.8,66.8,63.8,63.9,59.2,60.0


In [10]:
# Concatenate the frames collected in df_list column-wise (axis=1) and display the result.
all_df = pd.concat(df_list, axis=1)
# all_df.to_csv(f"./RE-results.csv")
# for model in model_name_list:
#     col_pca = f"{model}+pca"
#     col_x = f"{model}+x"
#     all_df[f"{model} % change"] = ((all_df[col_pca] - all_df[col_x]) / all_df[col_x] * 100).round(1)
# all_df[sorted(all_df.columns, reverse=True)]
all_df

Unnamed: 0,all-mpnet-base-v2+x+,all-mpnet-base-v2+pca+360,all-mpnet-base-v2+pca+128
MAP@1,47.8,48.5,45.9
MAP@10,58.4,58.2,56.1
MAP@100,59.4,59.2,57.1
MAP@1000,59.5,59.2,57.1
MAP@3,55.3,55.5,53.2
MAP@5,57.4,56.9,54.8
NDCG@1,49.7,50.7,48.0
NDCG@10,63.3,63.1,61.1
NDCG@100,67.4,67.0,65.5
NDCG@1000,68.0,67.7,66.2


In [12]:
# Concatenate the frames collected in df_list column-wise (axis=1).
all_df = pd.concat(df_list, axis=1)
# all_df.to_csv(f"./RE-results.csv")
# all_df
# for model in model_name_list:
#     col_pca = f"{model}+pca"
#     col_x = f"{model}+x"
#     all_df[f"{model} % change"] = ((all_df[col_pca] - all_df[col_x]) / all_df[col_x] * 100).round(1)
# all_df[sorted(all_df.columns, reverse=True)]

Unnamed: 0,all-mpnet-base-v2+x,all-mpnet-base-v2+pca,all-mpnet-base-v2 % change,all-MiniLM-L12-v1+x,all-MiniLM-L12-v1+pca,all-MiniLM-L12-v1 % change
MAP@1,0.47789,0.45872,-4.0,0.48139,0.45206,-6.1
MAP@10,0.58434,0.56075,-4.0,0.57537,0.54391,-5.5
MAP@100,0.59424,0.57104,-3.9,0.58381,0.55368,-5.2
MAP@1000,0.59454,0.57132,-3.9,0.58409,0.55398,-5.2
MAP@3,0.55327,0.53157,-3.9,0.54461,0.51169,-6.0
MAP@5,0.57367,0.54806,-4.5,0.56156,0.52871,-5.8
NDCG@1,0.49667,0.48,-3.4,0.49667,0.47,-5.4
NDCG@10,0.63309,0.61067,-3.5,0.6217,0.59294,-4.6
NDCG@100,0.67379,0.65479,-2.8,0.6599,0.63539,-3.7
NDCG@1000,0.68032,0.66179,-2.7,0.66767,0.64449,-3.5
