### **Full Explanation of the Sentence Transformer Implementation (PyTorch)**

This script implements a **Sentence Transformer** model using **PyTorch** and **Hugging Face's `transformers` library**. The model encodes input sentences into fixed-length embeddings using a **pre-trained transformer** (like BERT) and applies **mean pooling** to aggregate token representations into a single sentence vector.

---

```python
# Importing Required Libraries.
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer
```

- `torch` provides tensors and tensor operations, and `torch.nn` provides the `nn.Module` base class used to define the model.
- `AutoModel` and `AutoTokenizer` from Hugging Face are used to load a **pre-trained transformer model** and its corresponding tokenizer.

**Explanation:**
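- The class `SentenceTransformer` inherits from `nn.Module`, making it a standard PyTorch model. Despite the name, it is defined in this script and does not use the `sentence-transformers` library, which exports a class of the same name.
- `AutoModel.from_pretrained(model_name)` loads a **pre-trained transformer encoder** (`bert-base-uncased` by default) and stores it as `self.encoder`.
- `AutoTokenizer.from_pretrained(model_name)` loads the **matching tokenizer** and stores it as `self.tokenizer`. The tokenizer must match the model so that token IDs map to the embeddings the model was trained with.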
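Because `self.encoder` is assigned as an attribute of an `nn.Module`, PyTorch registers it as a submodule: its weights show up in `parameters()`, and calls such as `.eval()` or `.to(device)` propagate to it. The tokenizer is a plain Python object and is not registered. The sketch below illustrates this with a hypothetical `ToyEncoder` (a small `nn.Embedding`) standing in for the pre-trained transformer, so it runs without downloading any weights; `ToyEncoder` and `ToySentenceModel` are illustrative names, not part of the original script.

```python
import torch
import torch.nn as nn

class ToyEncoder(nn.Module):
    """Hypothetical stand-in for AutoModel: maps token ids to vectors."""
    def __init__(self, vocab_size=10, hidden_size=4):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)

    def forward(self, input_ids):
        return self.embed(input_ids)

class ToySentenceModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = ToyEncoder()                # nn.Module attribute -> registered submodule
        self.tokenizer = {"hello": 1, "world": 2}  # plain object (hypothetical vocab) -> not registered

model = ToySentenceModel()
n_params = sum(p.numel() for p in model.parameters())
print(n_params)                                       # 40 (10 x 4 embedding weights)
print([name for name, _ in model.named_children()])  # ['encoder']

model.eval()
print(model.encoder.training)  # False: eval() propagated to the submodule
```

The same mechanism is why `SentenceTransformer` can be moved to a GPU or switched to evaluation mode with a single call on the outer module.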

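**Why mean pooling?**
- Transformer models like BERT generate **token embeddings** (one vector per token), so the number of vectors varies with sentence length.
- A sentence embedding should be a **single fixed-size vector**, regardless of sentence length.
- **Mean pooling** averages the token embeddings, using the **attention mask** to exclude padding tokens. Special tokens such as `[CLS]` and `[SEP]` have a mask value of 1, so they are included in the average.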
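Written as a formula, with $\mathbf{h}_t$ the embedding of token $t$, $m_t \in \{0, 1\}$ its attention-mask value and $T$ the padded sequence length, the pooled sentence vector $\mathbf{s}$ is given below. The $10^{-9}$ floor mirrors the `torch.clamp(..., min=1e-9)` call in `mean_pooling`.

$$
\mathbf{s} = \frac{\sum_{t=1}^{T} m_t \, \mathbf{h}_t}{\max\left(\sum_{t=1}^{T} m_t,\; 10^{-9}\right)}
$$

For example, with two real tokens $\mathbf{h}_1 = (1, 2)$, $\mathbf{h}_2 = (3, 4)$ and one padding token $\mathbf{h}_3 = (100, 100)$, the mask $(1, 1, 0)$ gives $\mathbf{s} = \big((1 + 3)/2,\ (2 + 4)/2\big) = (2, 3)$; the padding vector has no effect.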

```python
# Defining the Sentence Transformer Model.

class SentenceTransformer(nn.Module):
    def __init__(self, model_name="bert-base-uncased"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)        # pre-trained transformer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)  # matching tokenizer

    # Implementing Mean Pooling.
    def mean_pooling(self, token_embeddings, attention_mask):
        """Compute mean pooling over token embeddings based on attention mask."""
        # (batch, seq_len) -> (batch, seq_len, hidden_size); cast to float so the
        # 1e-9 floor below is representable (integer tensors cannot hold it).
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)  # Avoid division by zero
        return sum_embeddings / sum_mask

    # Forward Pass (Encoding Sentences)
    def forward(self, sentences):
        inputs = self.tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
        inputs = inputs.to(self.encoder.device)  # keep inputs on the same device as the model
        outputs = self.encoder(**inputs)
        sentence_embeddings = self.mean_pooling(outputs.last_hidden_state, inputs["attention_mask"])
        return sentence_embeddings
```

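### **Step-by-Step Breakdown of `mean_pooling`:**
1. **Expand the Attention Mask:**  
   - `attention_mask.unsqueeze(-1).expand(token_embeddings.size())`  
   - Reshapes the mask from `(batch_size, seq_len)` to `(batch_size, seq_len, hidden_size)` so it lines up element-wise with the token embeddings.
2. **Multiply Token Embeddings by Attention Mask:**  
   - Padding positions (mask value 0) become zero vectors, so they cannot contribute to the mean.
3. **Sum the Token Embeddings Along Dimension 1:**  
   - Dimension 1 is the sequence dimension, so this adds up the embeddings of all non-padding tokens and yields a `(batch_size, hidden_size)` tensor.
4. **Compute the Mean:**  
   - Divide by the **sum of the attention mask** (the number of non-padding tokens) to get the **average** embedding. `torch.clamp(..., min=1e-9)` keeps the denominator from being zero.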
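The four steps can be traced on a tiny hand-made batch. The sketch below uses made-up numbers (2 sentences, padded length 3, hidden size 2) instead of real BERT outputs, and casts the mask to `float` explicitly before the division.

```python
import torch

# Toy batch (made-up numbers): 2 sentences, padded length 3, hidden size 2.
token_embeddings = torch.tensor([
    [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],  # sentence 1: 2 real tokens + 1 pad
    [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]],     # sentence 2: 3 real tokens
])
attention_mask = torch.tensor([[1, 1, 0], [1, 1, 1]])  # (batch, seq_len)

# Step 1: (batch, seq_len) -> (batch, seq_len, hidden) to line up with the embeddings.
mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
print(mask.shape)  # torch.Size([2, 3, 2])

# Step 2: zero out the padding positions.
masked = token_embeddings * mask

# Step 3: sum over the sequence dimension (dim=1) -> (batch, hidden).
summed = masked.sum(dim=1)

# Step 4: divide by the number of real tokens, floored at 1e-9 to avoid 0/0.
counts = torch.clamp(mask.sum(dim=1), min=1e-9)
pooled = summed / counts
print(pooled.tolist())  # [[2.0, 3.0], [7.0, 8.0]]
```

The `[100.0, 100.0]` padding vector in sentence 1 would dominate a naive `token_embeddings.mean(dim=1)`, but it has no effect on the masked mean.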

### **Processing Input Sentences (`forward`):**
1. **Tokenization:**
   - `self.tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")`
   - Converts input text into token IDs, an attention mask and (for BERT) token type IDs, in the format the transformer expects.
   - `padding=True` pads every sentence to the length of the longest sentence in the batch.
   - `truncation=True` cuts inputs longer than the model's maximum length (512 tokens for `bert-base-uncased`).
   - `return_tensors="pt"` returns **PyTorch tensors**.

2. **Passing Through Transformer:**
   - `outputs = self.encoder(**inputs)`
   - Generates **contextualized token embeddings**; `outputs.last_hidden_state` has shape `(batch_size, seq_len, hidden_size)`.

3. **Applying Mean Pooling:**
   - Converts the token embeddings into a **single sentence embedding** per input sentence.

4. **Returning the Sentence Embeddings:**
   - The output has shape `(batch_size, hidden_size)`: one **fixed-length vector** per sentence (768 dimensions for `bert-base-uncased`).
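Because `padding=True` pads every sentence to the longest one in the batch, a sentence's embedding should not depend on which other sentences it is batched with. For the pooling step this can be checked in isolation. The sketch below uses a standalone re-implementation of the same masked mean (the function name `mean_pooling` is reused for illustration) and random tensors in place of real encoder outputs. With a real encoder, padded positions are also masked inside attention, so the real-token outputs should agree up to small floating-point differences, but that part is not exercised here.

```python
import torch

def mean_pooling(token_embeddings, attention_mask):
    """Standalone masked mean, equivalent to SentenceTransformer.mean_pooling."""
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return (token_embeddings * mask).sum(1) / torch.clamp(mask.sum(1), min=1e-9)

torch.manual_seed(0)
hidden = 8
real_tokens = torch.randn(1, 4, hidden)  # one sentence with 4 real tokens

# Pooled alone: no padding.
alone = mean_pooling(real_tokens, torch.ones(1, 4, dtype=torch.long))

# Padded to length 7 with large arbitrary vectors in the pad slots.
padded_tokens = torch.cat([real_tokens, 50 * torch.randn(1, 3, hidden)], dim=1)
padded_mask = torch.tensor([[1, 1, 1, 1, 0, 0, 0]])
padded = mean_pooling(padded_tokens, padded_mask)

print(torch.allclose(alone, padded))  # True: pad positions are ignored
```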

```python
# Model Initialization and Testing
sentence_transformer = SentenceTransformer()
sample_sentences = ["This is a test sentence.", "Sentence transformers generate embeddings."]

with torch.no_grad():
    embeddings = sentence_transformer(sample_sentences)

print("Embeddings shape:", embeddings.shape)  # Expected: (batch_size, hidden_size)
print("Sample Embeddings:", embeddings)
```

Output:

```
Embeddings shape: torch.Size([2, 768])
Sample Embeddings: tensor([[ 6.6063e-02, -2.1769e-01, -1.5390e-01,  ..., -2.2918e-01,
         -3.9263e-04,  3.9901e-01],
        [ 6.9934e-02, -3.2461e-01, -2.9427e-01,  ..., -1.5882e-01,
         -5.1646e-01,  1.5244e-01]])
```
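A common next step with such embeddings is comparing sentences by cosine similarity. The helper below is an illustrative sketch; the name `cosine_similarity_matrix` is hypothetical and not part of the original script. It L2-normalizes each row with `torch.nn.functional.normalize` and takes a matrix product, so entry `[i, j]` is the cosine similarity between sentence `i` and sentence `j`. It is shown on small hand-made vectors; with the script above it would be called as `cosine_similarity_matrix(embeddings)`.

```python
import torch
import torch.nn.functional as F

def cosine_similarity_matrix(embeddings: torch.Tensor) -> torch.Tensor:
    """Pairwise cosine similarities for a (batch_size, hidden_size) tensor."""
    unit = F.normalize(embeddings, p=2, dim=1)  # scale every row to length 1
    return unit @ unit.T                        # (batch_size, batch_size)

# Hand-made 2-D vectors standing in for real sentence embeddings.
toy = torch.tensor([[1.0, 0.0],
                    [0.0, 2.0],
                    [3.0, 3.0]])
sims = cosine_similarity_matrix(toy)
print(sims.shape)  # torch.Size([3, 3])
```

Normalizing once and using a single matrix product computes all pairs at once, which is simpler than calling a pairwise similarity function in a loop.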
