In [5]:
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn
import torch.nn.functional as F

class MultiTaskSentenceTransformer(nn.Module):
    """Shared transformer encoder feeding two linear classification heads.

    Sentences are tokenized, encoded, mean-pooled over their non-padding
    tokens and L2-normalised into one embedding per sentence. That single
    embedding is the input of both task heads.

    Args:
        model_name: Identifier handed to ``AutoTokenizer.from_pretrained``
            and ``AutoModel.from_pretrained``.
        num_labels_task_a: Output size of the Task A head (sentence
            classification).
        num_labels_task_b: Output size of the Task B head (sentiment
            analysis).
    """

    def __init__(self, model_name='distilbert-base-uncased', num_labels_task_a=3, num_labels_task_b=2):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.encoder = AutoModel.from_pretrained(model_name)

        hidden_size = self.encoder.config.hidden_size
        # Informational only: nothing in this class reads it, mean pooling is
        # always applied. Kept so existing code accessing `.pooling` still works.
        self.pooling = 'mean'

        # Both heads are plain linear layers on top of the pooled embedding.
        self.classifier_task_a = nn.Linear(hidden_size, num_labels_task_a)  # Task A: sentence classification
        self.classifier_task_b = nn.Linear(hidden_size, num_labels_task_b)  # Task B: sentiment analysis

    def mean_pooling(self, outputs, attention_mask):
        """Average the token embeddings of each sequence, ignoring masked tokens.

        Args:
            outputs: Encoder output exposing ``last_hidden_state`` with shape
                (batch, seq_len, hidden).
            attention_mask: (batch, seq_len) tensor; positions with value 0
                do not contribute to the average.

        Returns:
            A (batch, hidden) float tensor of per-sequence mean embeddings.
        """
        token_embeddings = outputs.last_hidden_state
        # Broadcast the mask over the hidden dimension so it can zero out
        # the embeddings of masked positions.
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
        sum_mask = input_mask_expanded.sum(1)
        # The clamp avoids a division by zero for a fully masked sequence,
        # which then pools to an all-zero vector.
        return sum_embeddings / torch.clamp(sum_mask, min=1e-9)

    def forward(self, sentences):
        """Embed the sentences and compute the logits of both tasks.

        Args:
            sentences: Text input forwarded unchanged to the tokenizer
                (the demo passes a list of strings).

        Returns:
            A tuple ``(logits_task_a, logits_task_b, sentence_embeddings)``
            where the logits have shape (batch, num_labels_task_*) and the
            embeddings are L2-normalised with shape (batch, hidden).
        """
        inputs = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')

        # The tokenizer creates its tensors on the CPU. Move them next to the
        # encoder weights so the model keeps working after `.to(device)`;
        # without this a model placed on an accelerator fails with a device
        # mismatch.
        device = next(self.encoder.parameters()).device
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

        outputs = self.encoder(**inputs)
        sentence_embeddings = self.mean_pooling(outputs, inputs['attention_mask'])
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)

        logits_task_a = self.classifier_task_a(sentence_embeddings)
        logits_task_b = self.classifier_task_b(sentence_embeddings)

        return logits_task_a, logits_task_b, sentence_embeddings

if __name__ == '__main__':
    # Demo: build the model with its default arguments and switch it to
    # evaluation mode before running a handful of sentences through it.
    # NOTE: the two classification heads are newly created nn.Linear layers
    # and no training step is visible in this script, so the predicted ids
    # below should not be expected to match the intended labels.
    model = MultiTaskSentenceTransformer()
    model.eval()

    # Example inputs; the trailing comments give the intended labels.
    sentences = [
        "It's sunny today.",                                      # Task A: Weather, Task B: Positive
        "This laptop has great battery life.",                    # Task A: Technology, Task B: Positive
        "The football match was disappointing.",                  # Task A: Sports, Task B: Negative
        "AI is revolutionizing healthcare and finance.",          # Task A: Technology, Task B: Positive
        "Rain dampened the final leg of the cycling tournament."  # Task A: Sports, Task B: Negative
    ]

    # Label ids (for testing later):
    #   Task A (sentence classification): 0 = Technology, 1 = Weather, 2 = Sports
    #   Task B (sentiment):               0 = Negative,   1 = Positive

    # Forward pass without gradient tracking; the highest logit of each head
    # gives the predicted class id.
    with torch.no_grad():
        logits_a, logits_b, embeddings = model(sentences)
        preds_a = logits_a.argmax(dim=1)
        preds_b = logits_b.argmax(dim=1)

    # One report per sentence: both predicted ids plus the first five
    # dimensions of its embedding.
    for sentence, topic, sentiment, vector in zip(sentences, preds_a, preds_b, embeddings):
        print(f"\nSentence: {sentence}")
        print(f"Predicted Topic (Task A): {topic.item()} | Predicted Sentiment (Task B): {sentiment.item()}")
        print(f"Embedding preview: {vector[:5].numpy()} ...")  # First 5 dimensions



Sentence: It's sunny today.
Predicted Topic (Task A): 1 | Predicted Sentiment (Task B): 0
Embedding preview: [ 0.0066951  -0.02386935  0.00145519  0.00737709  0.0201737 ] ...

Sentence: This laptop has great battery life.
Predicted Topic (Task A): 2 | Predicted Sentiment (Task B): 0
Embedding preview: [-0.00542451 -0.02068131  0.02674492  0.01063831  0.03491297] ...

Sentence: The football match was disappointing.
Predicted Topic (Task A): 1 | Predicted Sentiment (Task B): 1
Embedding preview: [-0.00851533 -0.03751256 -0.02684303  0.0094255  -0.02344177] ...

Sentence: AI is revolutionizing healthcare and finance.
Predicted Topic (Task A): 2 | Predicted Sentiment (Task B): 1
Embedding preview: [ 0.00240536 -0.00465826  0.0050031   0.0310524   0.03569923] ...

Sentence: Rain dampened the final leg of the cycling tournament.
Predicted Topic (Task A): 1 | Predicted Sentiment (Task B): 1
Embedding preview: [-0.01912113 -0.03690866  0.01899971  0.00336774  0.01271178] ...
