# In [None]:
# Task 2: Multi-Task Learning Expansion

'''
Goals:

Task A: Sentence classification (e.g., classify a sentence as News, Sports, Finance, etc.).

Task B: A second NLP task; we chose sentiment analysis (Positive, Negative, Neutral).

Design:

Use SentenceTransformer('paraphrase-MiniLM-L6-v2') for encoding input sentences.

Add two separate classification heads:

task_a_head: MLP for topic classification.

task_b_head: MLP for sentiment analysis.

Both heads take the same sentence embedding as input, produced by the shared backbone (kept frozen or fine-tuned).
The architecture is easy to extend: supporting another task only requires adding a new head.
'''




from sentence_transformers import SentenceTransformer
import torch
from torch import nn

class MultiTaskSentenceTransformer(nn.Module):
    """Shared SentenceTransformer encoder with one MLP classification head per task.

    Task 'A' is sentence (topic) classification and task 'B' is sentiment
    analysis. Both heads consume the same sentence embedding, so adding a
    task only requires adding another head.

    Args:
        model_name: Model name or path passed to ``SentenceTransformer``.
        task_a_num_classes: Number of task-A logits,
            e.g. ['News', 'Sports', 'Finance'].
        task_b_num_classes: Number of task-B logits,
            e.g. ['Positive', 'Negative', 'Neutral'].
        freeze_encoder: If True (default, matching the original behaviour),
            embeddings come from ``SentenceTransformer.encode``, an inference
            API that does not track gradients, so only the heads are trained.
            If False, the encoder's differentiable forward pass is used so
            the backbone is fine-tuned together with the heads.
        hidden_dim: Width of each head's hidden layer.
        dropout: Dropout probability inside each head.
    """

    def __init__(self, model_name='paraphrase-MiniLM-L6-v2',
                 task_a_num_classes=3,  # e.g., ['News', 'Sports', 'Finance']
                 task_b_num_classes=3,  # e.g., ['Positive', 'Negative', 'Neutral']
                 freeze_encoder=True,
                 hidden_dim=128,
                 dropout=0.1):
        super().__init__()

        # Shared sentence encoder
        self.encoder = SentenceTransformer(model_name)
        self.embedding_dim = self.encoder.get_sentence_embedding_dimension()
        self.freeze_encoder = freeze_encoder

        # Task A: sentence classification head; Task B: sentiment analysis head.
        # Attribute names and layer order are unchanged, so existing
        # state_dicts (task_a_head.0.weight, ...) still load.
        self.task_a_head = self._make_head(
            self.embedding_dim, task_a_num_classes, hidden_dim, dropout)
        self.task_b_head = self._make_head(
            self.embedding_dim, task_b_num_classes, hidden_dim, dropout)

    @staticmethod
    def _make_head(embedding_dim, num_classes, hidden_dim=128, dropout=0.1):
        """Build a two-layer MLP mapping an embedding to ``num_classes`` logits."""
        return nn.Sequential(
            nn.Linear(embedding_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def _encode_with_grad(self, sentences):
        """Embed a list of sentences through the encoder's differentiable forward pass."""
        features = self.encoder.tokenize(sentences)
        device = self.encoder.device
        features = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in features.items()
        }
        return self.encoder(features)['sentence_embedding']

    def forward(self, sentences, task='A'):
        """Return the requested task head's logits for ``sentences``.

        Args:
            sentences: A single sentence (str) or a sequence of sentences.
            task: 'A' for sentence classification, 'B' for sentiment analysis.

        Returns:
            Logits of shape (num_classes,) for a single str input, otherwise
            (batch, num_classes). An empty sequence yields (0, num_classes).

        Raises:
            ValueError: If ``task`` is not 'A' or 'B'.
        """
        # Resolve the head first so an invalid task fails before any costly encoding.
        if task == 'A':
            head = self.task_a_head
        elif task == 'B':
            head = self.task_b_head
        else:
            raise ValueError("Invalid task. Use 'A' or 'B'.")

        # The encoder may be loaded onto a different device (e.g. GPU) than the
        # heads, so embeddings are always moved to the head's device.
        head_device = next(head.parameters()).device

        single = isinstance(sentences, str)
        if not single and len(sentences) == 0:
            # encode() returns a shape-less tensor for empty input; return a proper empty batch.
            return head(torch.empty(0, self.embedding_dim, device=head_device))

        if self.freeze_encoder:
            # Inference-only path: no gradients reach the encoder.
            embeddings = self.encoder.encode(sentences, convert_to_tensor=True)
        else:
            embeddings = self._encode_with_grad([sentences] if single else list(sentences))
            if single:
                # Match encode()'s 1-D output for a single string.
                embeddings = embeddings.squeeze(0)

        return head(embeddings.to(head_device))

