# Variant 2: TFIDF-SRT-EMB-LegalBERT
#### Overview:
• This variant follows the same deduplication and TF-IDF-based reordering as Variant 1.
• In addition, the model includes a TF-IDF embedding layer.
• For each token in the input, you compute (or look up) its TF-IDF score, convert that into a “bucket” (or discretized index) based on predetermined bins, and then use a learnable embedding to map that bucket to a vector.
• This additional embedding is projected to the model's hidden size and then added to the regular token embedding.
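
The pipeline in the bullets above can be sketched on toy tensors. This is a minimal illustration, not the implementation used later: the bin boundaries, the sizes, and the random stand-in for the token embeddings are all assumptions, and `torch.bucketize` is used here in place of a hand-written bucketing helper.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes and bin edges (assumed for illustration only).
num_buckets, tfidf_dim, hidden_size = 4, 8, 16
boundaries = torch.tensor([1.0, 2.0, 4.0])  # 3 edges define 4 buckets

# One sequence of four tokens with made-up TF-IDF scores.
tfidf_scores = torch.tensor([[0.3, 1.5, 2.0, 7.2]])

# Step 1: discretize each score into a bucket index.
bucket_ids = torch.bucketize(tfidf_scores, boundaries)

# Step 2: look up a learnable vector per bucket and project it to hidden_size.
tfidf_embedding = nn.Embedding(num_buckets, tfidf_dim)
proj = nn.Linear(tfidf_dim, hidden_size)
bucket_vectors = proj(tfidf_embedding(bucket_ids))

# Step 3: add it to the token embeddings (random stand-in for LegalBERT's).
token_embeddings = torch.randn(1, 4, hidden_size)
combined = token_embeddings + bucket_vectors

print(bucket_ids.tolist())  # [[0, 1, 1, 3]]
print(combined.shape)       # torch.Size([1, 4, 16])
```

Because the two embeddings are summed, the TF-IDF signal does not change the sequence length or the hidden size that the encoder sees.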

## Explanation Variant 2:

##### Preprocessing:
– We use nearly the same preprocessing function as in Variant 1. In addition to the processed text, it returns the list of TF-IDF scores for the ordered tokens.
– A helper function, bucketize_tfidf_score, maps each continuous TF-IDF score into a discrete bucket (the number of buckets is a hyperparameter, e.g., 32).
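
As a worked example of the bucketing step, the sketch below repeats the linear bucketing rule of bucketize_tfidf_score and applies it to a few made-up scores. The score range of 0 to 10 is the default from the implementation, and the sample scores are assumptions chosen to show clipping at both ends.

```python
def bucketize_tfidf_score(score, num_buckets, min_score=0.0, max_score=10.0):
    """Map a continuous score to a bucket index in [0, num_buckets - 1]."""
    norm_score = (score - min_score) / (max_score - min_score)
    norm_score = max(0.0, min(1.0, norm_score))  # clip out-of-range scores
    return int(norm_score * (num_buckets - 1))

sample_scores = [0.0, 2.5, 5.0, 10.0, 12.0]  # 12.0 is above max_score
buckets = [bucketize_tfidf_score(s, num_buckets=32) for s in sample_scores]
print(buckets)  # [0, 7, 15, 31, 31]
```

Scores above max_score all land in the top bucket, so it is worth checking min_score and max_score against the actual score distribution of the corpus.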

##### TFIDF_EMB_LegalBERT Model:
– The model wraps a standard LegalBERT and adds a learnable embedding layer (tfidf_embedding) that maps bucket indices to embedding vectors.
– In the forward pass, the standard token embeddings (obtained from LegalBERT’s embedding layer) are combined with the projected TF-IDF embeddings (after passing through a linear projection to align dimensions).
– The combined embeddings are then passed to LegalBERT using the inputs_embeds parameter.
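
The inputs_embeds mechanism can be tried without downloading LegalBERT. The sketch below uses a tiny, randomly initialized BERT classifier as a stand-in; the configuration sizes, the three-class setup, and the random inputs are assumptions for illustration. It also shows that the returned loss is only populated when labels are supplied.

```python
import torch
from transformers import BertConfig, BertForSequenceClassification

torch.manual_seed(0)

# Tiny random BERT (assumed sizes) standing in for LegalBERT.
config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=1,
    num_attention_heads=2, intermediate_size=64, num_labels=3,
)
tiny_bert = BertForSequenceClassification(config).eval()

# Random toy batch: 2 sequences of 6 tokens.
input_ids = torch.randint(0, 100, (2, 6))
attention_mask = torch.ones(2, 6, dtype=torch.long)
tfidf_bucket_ids = torch.randint(0, 32, (2, 6))
labels = torch.tensor([0, 2])

# Extra layers: bucket embedding plus projection to the hidden size.
tfidf_embedding = torch.nn.Embedding(32, 50)
proj = torch.nn.Linear(50, config.hidden_size)

# Word embeddings only; BERT adds position and token-type embeddings itself.
word_emb = tiny_bert.bert.embeddings.word_embeddings(input_ids)
combined = word_emb + proj(tfidf_embedding(tfidf_bucket_ids))

with torch.no_grad():
    no_labels = tiny_bert(inputs_embeds=combined, attention_mask=attention_mask)
    with_labels = tiny_bert(inputs_embeds=combined,
                            attention_mask=attention_mask, labels=labels)

print(no_labels.loss)            # None, because no labels were supplied
print(with_labels.logits.shape)  # torch.Size([2, 3])
```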

##### Data Preparation:
– After preprocessing, we tokenize the processed text and build a tensor of TF-IDF bucket IDs with the same shape as input_ids, so that every position (including special and padding tokens) has a bucket.
– These bucket IDs are then provided to the model along with the usual inputs.
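
One detail of the alignment is easy to miss: a BERT tokenizer wraps each sequence in [CLS] and [SEP], so the per-token buckets start at position 1, not 0. A minimal sketch of the layout follows. The helper name align_bucket_ids and the choice of bucket 0 for special and padding positions are assumptions for illustration, not part of the original code.

```python
def align_bucket_ids(bucket_ids, seq_len, special_bucket=0):
    """Lay out buckets as [CLS] t1..tn [SEP] [PAD].. for a row of seq_len >= 2."""
    body = bucket_ids[: seq_len - 2]           # leave room for [CLS] and [SEP]
    row = [special_bucket] + body + [special_bucket]
    row += [special_bucket] * (seq_len - len(row))  # pad up to seq_len
    return row

print(align_bucket_ids([31, 20, 7], seq_len=7))        # [0, 31, 20, 7, 0, 0, 0]
print(align_bucket_ids([31, 20, 7, 3, 1], seq_len=5))  # [0, 31, 20, 7, 0]
```

Sharing bucket 0 between special tokens and genuinely low-scoring tokens is a simplification; reserving a dedicated index for special and padding positions is one alternative.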

##### Training / Inference:
– With this modified model, you can fine-tune on your classification task and compare the performance and efficiency to the baseline and to Variant 1.
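
A fine-tuning step has the usual shape: forward pass with the three inputs, loss against the labels, backward pass, optimizer update. The sketch below illustrates this with a small stand-in model so that it runs without pretrained weights. ToyBucketClassifier, its sizes, the learning rate, and the random batch are all hypothetical; with the real model, the same loop would call the LegalBERT-based variant instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class ToyBucketClassifier(nn.Module):
    """Hypothetical stand-in taking the same inputs as the Variant 2 model."""
    def __init__(self, vocab_size=50, num_buckets=32, hidden=16, num_labels=2):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, hidden)
        self.bucket = nn.Embedding(num_buckets, hidden)
        self.head = nn.Linear(hidden, num_labels)

    def forward(self, input_ids, attention_mask, tfidf_bucket_ids):
        x = self.tok(input_ids) + self.bucket(tfidf_bucket_ids)
        mask = attention_mask.unsqueeze(-1).float()
        pooled = (x * mask).sum(1) / mask.sum(1).clamp(min=1.0)  # masked mean
        return self.head(pooled)

model = ToyBucketClassifier()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

# Random toy batch: 4 sequences of 6 tokens with binary labels.
input_ids = torch.randint(0, 50, (4, 6))
attention_mask = torch.ones(4, 6, dtype=torch.long)
tfidf_bucket_ids = torch.randint(0, 32, (4, 6))
labels = torch.tensor([0, 1, 0, 1])

model.train()
losses = []
for _ in range(3):
    optimizer.zero_grad()
    logits = model(input_ids, attention_mask, tfidf_bucket_ids)
    loss = loss_fn(logits, labels)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses)
```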


Below is an example implementation. In this example, we override the forward pass to add the extra TF-IDF embedding to the standard token embeddings.

In [None]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Assume we reuse the same model name and tokenizer as before.
model_name = "nlpaueb/legal-bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the base LegalBERT model.
base_model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Define a new model class for TFIDF-SRT-EMB-LegalBERT.
class TFIDF_EMB_LegalBERT(nn.Module):
    def __init__(self, base_model, num_buckets=32, embedding_dim=50):
        """
        Args:
            base_model: The pre-trained LegalBERT model.
            num_buckets: Number of TF-IDF buckets to discretize the continuous scores.
            embedding_dim: Dimension for the TF-IDF embeddings.
        """
        super().__init__()
        self.base_model = base_model
        # Define an extra embedding layer for TF-IDF buckets.
        self.tfidf_embedding = nn.Embedding(num_buckets, embedding_dim)
        self.num_buckets = num_buckets
        self.embedding_dim = embedding_dim
        
        # A simple projection to match base model's hidden size if needed.
        hidden_size = base_model.config.hidden_size
        self.proj = nn.Linear(embedding_dim, hidden_size)
    
    def forward(self, input_ids, attention_mask, tfidf_bucket_ids, labels=None):
        """
        Args:
            input_ids: Tensor of token ids from the tokenizer.
            attention_mask: Attention mask for padding tokens.
            tfidf_bucket_ids: Tensor of the same shape as input_ids,
                              indicating the bucket index for each token's TF-IDF score.
            labels: Optional tensor of class labels. The returned loss is None unless labels are given.
        """
        # Get the standard token embeddings from the base model's word-embedding layer.
        # Position and token-type embeddings are added by the base model when it receives inputs_embeds.
        token_embeddings = self.base_model.bert.embeddings.word_embeddings(input_ids)

        # Get the TF-IDF embeddings.
        bucket_embeddings = self.tfidf_embedding(tfidf_bucket_ids)
        # Project the TF-IDF embeddings into the same space as token embeddings.
        bucket_embeddings_proj = self.proj(bucket_embeddings)

        # Combine the embeddings by element-wise addition.
        combined_embeddings = token_embeddings + bucket_embeddings_proj

        # Pass the combined embeddings through the rest of the model via the "inputs_embeds" parameter.
        # input_ids must not be passed at the same time as inputs_embeds.
        outputs = self.base_model(
            attention_mask=attention_mask,
            inputs_embeds=combined_embeddings,
            labels=labels,
        )
        return outputs

# Helper function to bucketize TF-IDF scores with evenly spaced bins.
def bucketize_tfidf_score(score, num_buckets, min_score=0.0, max_score=10.0):
    """
    Map a continuous TF-IDF score to a discrete bucket.
    You can adjust min_score and max_score based on your corpus statistics.
    """
    # Normalize score to [0, 1]
    norm_score = (score - min_score) / (max_score - min_score)
    norm_score = max(0.0, min(1.0, norm_score))
    bucket = int(norm_score * (num_buckets - 1))
    return bucket

# Using the same preprocess function as Variant 1 to deduplicate and reorder tokens.
def preprocess_document_bow_variant(doc, tokenizer, idf_dict, max_length=512):
    tokens = tokenizer.tokenize(doc)
    unique_tokens = list(dict.fromkeys(tokens))
    token_scores = [(token, idf_dict.get(token, 0)) for token in unique_tokens]
    token_scores_sorted = sorted(token_scores, key=lambda x: x[1], reverse=True)
    ordered_tokens = [token for token, score in token_scores_sorted][:max_length]
    processed_text = " ".join(ordered_tokens)
    # Also return the corresponding scores for bucketization.
    # Note: each token's score here is its weight from idf_dict (0 for tokens missing from idf_dict).
    ordered_scores = [score for token, score in token_scores_sorted][:max_length]
    return processed_text, ordered_scores

# Assume we already computed the idf_dict as in Variant 1.
# Preprocess the documents and compute the bucket IDs per token.
processed_docs_variant2 = []
tfidf_buckets_list = []  # List of lists for bucket ids
for doc in documents:
    proc_text, scores = preprocess_document_bow_variant(doc, tokenizer, idf_dict)
    processed_docs_variant2.append(proc_text)
    # Compute bucket indices for each token based on the TF-IDF score.
    bucket_ids = [bucketize_tfidf_score(score, num_buckets=32) for score in scores]
    tfidf_buckets_list.append(bucket_ids)

# Tokenize the preprocessed documents.
encoded_inputs = tokenizer(processed_docs_variant2, padding=True, truncation=True,
                           max_length=512, return_tensors="pt")
input_ids = encoded_inputs["input_ids"]
attention_mask = encoded_inputs["attention_mask"]

# For demonstration, create a tensor for tfidf_bucket_ids that aligns with input_ids.
# The tokenizer adds [CLS] at position 0 and [SEP] after the last token, so the
# per-token bucket ids are shifted right by one; special and padding positions get bucket 0.
# This still assumes that re-tokenizing the processed text yields the same tokens in the same
# order, which may not hold for WordPiece continuation pieces ("##...").
# In a more careful implementation, you would map each token back to its bucket.
max_len = input_ids.size(1)
tfidf_bucket_ids = []
for bucket_ids in tfidf_buckets_list:
    # Reserve one position for [CLS] and one for [SEP].
    row = [0] + bucket_ids[:max_len - 2]
    # Pad with 0s to cover [SEP] and any padding tokens.
    row += [0] * (max_len - len(row))
    tfidf_bucket_ids.append(row)
tfidf_bucket_ids = torch.tensor(tfidf_bucket_ids, dtype=torch.long)

# Create an instance of the new model variant.
model_variant2 = TFIDF_EMB_LegalBERT(base_model, num_buckets=32, embedding_dim=50)

# Forward pass using variant 2.
outputs_variant2 = model_variant2(input_ids=input_ids, attention_mask=attention_mask,
                                  tfidf_bucket_ids=tfidf_bucket_ids)
loss_variant2 = outputs_variant2.loss  # None here, because no labels were passed to the model
logits_variant2 = outputs_variant2.logits

print("Variant 2 logits:", logits_variant2)
