# In [None]:
#%%

import torch
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import AutoTokenizer
from typing import List

# Hugging Face model identifiers, presumably candidate values for the
# `model_name` argument of SparseVectorizer (two of them are used as defaults
# in this file; the list itself is not referenced in the visible code).
# NOTE(review): despite the name, not every entry appears to be a BPE
# tokenizer -- the BERT-family checkpoints use WordPiece and T5/XLNet/ALBERT
# use SentencePiece. Confirm whether the constant should be renamed.
BPE_TOKENIZERS = [
    "openai-gpt",
    "t5-base",
    "roberta-base",
    "facebook/bart-base",
    "xlnet-base-cased",
    "bert-base-uncased",
    "albert-base-v2",
    "distilbert-base-uncased",
    "ctrl",
    "google/electra-small-generator",
]


class SparseVectorizer:
    """Sparse lexical vectorizer/scorer over a Hugging Face tokenizer vocabulary.

    Text is split into sub-word tokens by a pretrained tokenizer, counted by a
    scikit-learn ``TfidfVectorizer`` and turned into torch tensors.  ``fit``
    pre-computes Okapi BM25 weights for the corpus; ``query`` scores queries
    against those weights with cosine similarity.

    Attributes set by ``fit`` (all ``None`` until then):
        idf: per-term IDF values taken from the fitted ``TfidfVectorizer``.
        doc_len: number of tokens in each fitted document.
        avgdl: mean of ``doc_len``.
        bm25_weights: ``(n_documents, n_terms)`` BM25 weight matrix.
    """

    def __init__(self, model_name: str = 'bert-base-uncased', k1: float = 1.5, b: float = 0.75):
        """Load the tokenizer for *model_name* and store the BM25 parameters.

        Args:
            model_name: identifier passed to ``AutoTokenizer.from_pretrained``.
            k1: BM25 term-frequency saturation parameter.
            b: BM25 document-length normalisation parameter.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.k1 = k1
        self.b = b
        self.vectorizer = self._create_vectorizer()
        self.idf = None
        self.doc_len = None
        self.avgdl = None
        # Initialised here so that every fitted attribute exists before fit().
        self.bm25_weights = None

    def _create_vectorizer(self):
        """Build the unnormalised TF-IDF vectorizer that uses ``self.tokenize``."""
        return TfidfVectorizer(
            tokenizer=self.tokenize,
            # A custom tokenizer makes the default regex unused; passing None
            # avoids scikit-learn's "token_pattern will not be used" warning.
            token_pattern=None,
            lowercase=True,
            use_idf=True,
            norm=None,
            smooth_idf=False,
        )

    def tokenize(self, text: str) -> List[str]:
        """Return *text* as a list of decoded sub-word token strings.

        Each token id is decoded on its own, so the strings are whatever the
        tokenizer's ``decode`` produces for a single id.
        """
        token_ids = self.tokenizer.encode(text, add_special_tokens=False)
        return [self.tokenizer.decode([token_id]) for token_id in token_ids]

    @property
    def feature_names(self) -> List[str]:
        """Vocabulary terms, in column order, as reported by the fitted vectorizer."""
        return self.vectorizer.get_feature_names_out()

    @staticmethod
    def _bm25_document_weights(term_counts, idf, doc_len, avgdl, k1, b):
        """Return Okapi BM25 weights ``idf * tf * (k1 + 1) / (tf + k1 * norm)``.

        ``norm`` is ``1 - b + b * doc_len / avgdl``.  Entries whose term count
        is zero are set to exactly 0, which also keeps the ``k1 == 0`` and
        ``avgdl == 0`` corner cases free of NaN.
        """
        length_norm = 1 - b + b * doc_len.unsqueeze(1) / avgdl
        numerator = term_counts * idf.unsqueeze(0) * (k1 + 1)
        denominator = term_counts + k1 * length_norm
        present = term_counts > 0
        safe_denominator = torch.where(present, denominator, torch.ones_like(denominator))
        return torch.where(present, numerator / safe_denominator, torch.zeros_like(numerator))

    @staticmethod
    def _cosine_similarity(query_vectors, doc_vectors):
        """Return the ``(n_queries, n_docs)`` cosine-similarity matrix.

        Pairs involving an all-zero vector yield 0 instead of NaN.
        """
        dot_product = torch.mm(query_vectors, doc_vectors.t())
        query_norm = torch.norm(query_vectors, dim=1).unsqueeze(1)
        doc_norm = torch.norm(doc_vectors, dim=1).unsqueeze(0)
        return torch.nan_to_num(dot_product / (query_norm * doc_norm), nan=0.0)

    def fit(self, documents: List[str]):
        """Learn vocabulary and IDF from *documents* and pre-compute BM25 weights.

        Returns:
            ``self``, so calls can be chained.
        """
        tfidf = torch.FloatTensor(self.vectorizer.fit_transform(documents).toarray())
        self.idf = torch.FloatTensor(self.vectorizer.idf_)

        # With use_idf=True and norm=None the vectorizer returns tf * idf, not
        # raw counts.  Divide the IDF back out (it is >= 1 when
        # smooth_idf=False) and round away float32 noise to recover counts.
        term_counts = torch.round(tfidf / self.idf.unsqueeze(0))

        self.doc_len = term_counts.sum(dim=1)
        self.avgdl = self.doc_len.mean()
        self.bm25_weights = self._bm25_document_weights(
            term_counts, self.idf, self.doc_len, self.avgdl, self.k1, self.b
        )
        return self

    def transform(self, documents: List[str]) -> torch.Tensor:
        """Vectorize *documents* as saturated TF-IDF weights.

        Each non-zero TF-IDF value ``v`` becomes ``v * (k1 + 1) / (v + k1)``;
        no document-length normalisation is applied here.

        Returns:
            Dense tensor of shape ``(len(documents), n_terms)``.
        """
        # The vectorizer output is already IDF-weighted, so IDF must not be
        # multiplied in a second time.
        tfidf = torch.FloatTensor(self.vectorizer.transform(documents).toarray())

        occurring = tfidf > 0
        # Substitute a dummy denominator for absent terms to avoid 0/0 when k1 == 0.
        safe_denominator = torch.where(occurring, tfidf + self.k1, torch.ones_like(tfidf))
        saturated = tfidf * (self.k1 + 1) / safe_denominator
        return torch.where(occurring, saturated, torch.zeros_like(tfidf))

    def query(self, queries: List[str]) -> torch.Tensor:
        """Score *queries* against the fitted corpus.

        Returns:
            Tensor of shape ``(len(queries), n_fitted_documents)`` holding the
            cosine similarity between each query vector and each document's
            pre-computed BM25 weight vector.
        """
        query_vectors = self.transform(queries)
        return self._cosine_similarity(query_vectors, self.bm25_weights)

    @classmethod
    def from_documents(cls, documents: List[str], model_name: str = 'ctrl'):
        """Fit a new vectorizer on *documents* and return their vectors.

        Note that this returns the ``transform`` output for *documents*, not
        the vectorizer instance.
        """
        vectorizer = cls(model_name=model_name)
        vectorizer.fit(documents)
        return vectorizer.transform(documents)

    @classmethod
    def fit_transform_query(cls, documents: List[str], queries: List[str], model_name: str = 'ctrl'):
        """Fit a new vectorizer on *documents* and return the scores of *queries*.

        Returns:
            The ``query`` result: one row per query, one column per document.
        """
        vectorizer = cls(model_name=model_name)
        vectorizer.fit(documents)
        # query() scores against the corpus stored by fit(); it takes only the queries.
        return vectorizer.query(queries)

# Small demo corpus and queries used by the example cells below.
documents = [
    "Artificial General Intelligence is a subfield of Machine Learning.",
    "Natural Language Processing is crucial for AGI development.",
    "The AGI researcher published a paper on advanced NLP techniques."
]
queries = ["What is the relationship between AGI and NLP?", "Why do I care about AGI?"]

# Per-encoder outputs, keyed by the tokenizer/model name.
results = {}
features = {}
shapes = {}
# Example usage:
encoder = 'bert-base-uncased'
vectorizer = SparseVectorizer(model_name=encoder)

vectorizer.fit(documents)
features[encoder] = vectorizer.feature_names
vector_matrix = vectorizer.transform(documents)
shapes[encoder] = vector_matrix.shape

#%%
# Scoring example
# `query` is a method of SparseVectorizer (there is no module-level `query`
# function); it returns one row of scores per query, one column per document.
results[encoder] = vectorizer.query(queries)
results[encoder]

# %%
