In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModel, AutoTokenizer

In [4]:
# Sentence Classification Model Definition

class SentenceClassificationTransformer(nn.Module):
    """Pretrained transformer encoder with a mean-pooled linear classification head.

    Raw sentences are tokenized inside ``forward``, run through the encoder,
    averaged over their non-padding tokens, and projected to ``num_classes``
    logits.

    Args:
        model_name: Identifier passed to ``AutoModel.from_pretrained`` and
            ``AutoTokenizer.from_pretrained``.
        num_classes: Number of logits produced per sentence.
    """

    def __init__(self, model_name="bert-base-uncased", num_classes=3):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        # NOTE: forward() pools with mean_pooling() and never calls this layer.
        # It is retained only so code that accesses `.pooling` keeps working.
        self.pooling = nn.AdaptiveAvgPool1d(1)

        # Task A: Sentence Classification
        self.classifier = nn.Linear(self.encoder.config.hidden_size, num_classes)

    def mean_pooling(self, token_embeddings, attention_mask):
        """Average token embeddings along dim 1, weighted by the attention mask.

        Args:
            token_embeddings: Tensor whose dim 1 indexes tokens, e.g.
                ``(batch, seq_len, hidden)``.
            attention_mask: Tensor with one fewer trailing dimension than
                ``token_embeddings`` (e.g. ``(batch, seq_len)``); positions
                with value 0 do not contribute to the average.

        Returns:
            Tensor with dim 1 reduced away, e.g. ``(batch, hidden)``. A row
            whose mask is entirely zero yields zeros rather than NaN.
        """
        # Cast explicitly to float: the mask normally arrives as an integer
        # tensor, and the 1e-9 floor below is only meaningful in floating point.
        mask = attention_mask.unsqueeze(-1).float()
        summed = (token_embeddings * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # Avoid division by zero
        return summed / counts

    # Processing Sentences and Forward Pass
    def forward(self, sentences):
        """Tokenize ``sentences`` and return their classification logits.

        Args:
            sentences: A string or list of strings accepted by the tokenizer.

        Returns:
            Logits tensor of shape ``(batch_size, num_classes)``.
        """
        # The tokenizer always builds CPU tensors; move them to wherever the
        # encoder's weights live so the model also works after `.to(device)`.
        device = next(self.encoder.parameters()).device
        inputs = self.tokenizer(
            sentences, padding=True, truncation=True, return_tensors="pt"
        ).to(device)

        outputs = self.encoder(**inputs)
        pooled_output = self.mean_pooling(outputs.last_hidden_state, inputs["attention_mask"])

        # Task A: Sentence Classification
        class_logits = self.classifier(pooled_output)

        return class_logits

# Seed PyTorch's global RNG before building the model so the randomly
# initialized classifier head is the same on every Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Initialize sentence classification model
sentence_classification_model = SentenceClassificationTransformer()

# Defining Loss Function and Optimizer

# Define Loss Function
classification_loss_fn = nn.CrossEntropyLoss()

# Define Optimizer
LEARNING_RATE = 5e-5
optimizer = optim.AdamW(sentence_classification_model.parameters(), lr=LEARNING_RATE)

# Training Step: Forward Pass + Loss Computation

sample_sentences = ["This is a test sentence.", "Sentence transformers generate embeddings."]
sentence_labels = torch.tensor([0, 1])

# from_pretrained() hands back the encoder in eval mode (dropout disabled),
# so put the whole model into training mode before taking an optimizer step.
sentence_classification_model.train()

# Clear any gradients left over from a previous run of this cell
optimizer.zero_grad()

# Forward Pass
class_logits = sentence_classification_model(sample_sentences)

# Compute Loss (labels are moved to the logits' device so the two always match)
classification_loss = classification_loss_fn(
    class_logits, sentence_labels.to(class_logits.device)
)

# Backward Pass + Parameter Update
classification_loss.backward()
optimizer.step()

print("Sentence Classification Output Shape:", class_logits.shape)  # Expected: (batch_size, num_classes)


Sentence Classification Output Shape: torch.Size([2, 3])
