In [1]:
import torch
from torch import nn
from transformers import DistilBertModel, DistilBertTokenizer

In [2]:
# Multi-Task Model
class MultiTaskSentenceTransformer(nn.Module):
    """DistilBERT sentence encoder with two task-specific linear heads.

    A single pre-trained backbone turns each input sentence into one
    mean-pooled embedding. That shared embedding is fed to two independent
    linear classifiers, one per task, so both tasks share the backbone while
    each learns its own output layer.

    Args:
        model_name: Hugging Face checkpoint name or path. It is used to load
            both the tokenizer and the backbone.
        num_classes_task_a: Number of output logits produced by the Task A head.
        num_classes_task_b: Number of output logits produced by the Task B head.
    """

    def __init__(self, model_name="distilbert-base-uncased", num_classes_task_a=3, num_classes_task_b=4):
        super().__init__()
        # Load the pre-trained tokenizer and transformer backbone.
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        self.model = DistilBertModel.from_pretrained(model_name)

        hidden_size = self.model.config.hidden_size

        # Task A head: sentence classification. The notebook's notes describe
        # the labels as Positive/Negative/Neutral; the code itself only fixes
        # the number of classes (num_classes_task_a).
        self.task_a_classifier = nn.Linear(hidden_size, num_classes_task_a)

        # Task B head: sentiment analysis. The notes describe the labels as
        # Happy/Sad/Angry/Neutral; the code only fixes num_classes_task_b.
        self.task_b_classifier = nn.Linear(hidden_size, num_classes_task_b)

    @staticmethod
    def _mean_pool(token_embeddings, attention_mask):
        """Average token embeddings over the non-padding positions.

        Args:
            token_embeddings: Tensor of shape (batch, seq_len, hidden).
            attention_mask: Tensor of shape (batch, seq_len). A 1 marks a real
                token and a 0 marks padding.

        Returns:
            Tensor of shape (batch, hidden) with one embedding per sequence.
        """
        # Cast the (usually integer) mask to the embedding dtype so the
        # arithmetic stays in floating point. This also works under half precision.
        mask = attention_mask.unsqueeze(-1).to(token_embeddings.dtype)
        summed = (token_embeddings * mask).sum(dim=1)
        # Clamp the token count so a row with no real tokens yields zeros
        # instead of NaN (0 / 0).
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return summed / counts

    def forward(self, sentences):
        """Encode sentences and return logits for both tasks.

        Args:
            sentences: A string or a list of strings, passed directly to the
                tokenizer.

        Returns:
            Tuple ``(task_a_logits, task_b_logits)`` with shapes
            ``(batch, num_classes_task_a)`` and ``(batch, num_classes_task_b)``.
        """
        # Pad so variable-length sentences form one rectangular batch, and
        # truncate to the model's maximum length.
        inputs = self.tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
        # The tokenizer always returns CPU tensors. Move them onto the
        # backbone's device so the module also works after .to("cuda").
        inputs = inputs.to(self.model.device)

        # Forward pass through the shared backbone, then pool the per-token
        # hidden states into one sentence embedding.
        outputs = self.model(**inputs)
        sentence_embeddings = self._mean_pool(outputs.last_hidden_state, inputs["attention_mask"])

        # Both heads read the same shared sentence embedding.
        task_a_output = self.task_a_classifier(sentence_embeddings)
        task_b_output = self.task_b_classifier(sentence_embeddings)

        return task_a_output, task_b_output

In [3]:
# Instantiate the multi-task model: shared DistilBERT backbone,
# 3-way head for Task A and 4-way head for Task B.
model = MultiTaskSentenceTransformer(
    model_name="distilbert-base-uncased",
    num_classes_task_a=3,
    num_classes_task_b=4,
)

This model uses a similar architecture, built on the distilbert-base-uncased backbone; however, each task has its own output layer. Both heads read the same mean-pooled sentence embedding from the shared backbone. This lets the model learn task-specific patterns while sharing the backbone's representation across both tasks.