In [8]:
from transformers import DistilBertTokenizer, DistilBertModel
import torch.nn as nn
import torch


In [9]:
# Fake sentences
sentences = [
    "Artificial Intelligence is evolving rapidly.",  # Tech (Task A), Positive (Task B)
    "The government passed a new healthcare law.",    # Politics (Task A), Neutral (Task B)
]

# Tokenize
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

class SentenceTransformer(nn.Module):
    def __init__(self, model_name='distilbert-base-uncased'):
        super(SentenceTransformer, self).__init__()
        self.bert = DistilBertModel.from_pretrained(model_name)
        
    def forward(self, input_ids, attention_mask):
        output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Mean pooling
        last_hidden_state = output.last_hidden_state
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size())
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

class MultiTaskSentenceTransformer(nn.Module):
    def __init__(self, model_name='distilbert-base-uncased', 
                 num_classes_task_a=3, num_classes_task_b=3):
        super(MultiTaskSentenceTransformer, self).__init__()
        self.encoder = SentenceTransformer(model_name)
        hidden_size = 768  # DistilBERT hidden size

        # Task-specific heads
        self.classification_head = nn.Linear(hidden_size, num_classes_task_a)
        self.sentiment_head = nn.Linear(hidden_size, num_classes_task_b)

    def forward(self, input_ids, attention_mask, task='A'):
        embeddings = self.encoder(input_ids, attention_mask)
        if task == 'A':
            return self.classification_head(embeddings)
        elif task == 'B':
            return self.sentiment_head(embeddings)
        else:
            raise ValueError("Task must be 'A' or 'B'")

# Initialize model
mtl_model = MultiTaskSentenceTransformer()



In [10]:

# Forward pass for Task A (Classification)
logits_task_a = mtl_model(inputs['input_ids'], inputs['attention_mask'], task='A')
print("Task A (Classification) Logits:\n", logits_task_a)

Task A (Classification) Logits:
 tensor([[ 0.0254, -0.3884,  0.2180],
        [ 0.1403, -0.1674,  0.0340]], grad_fn=<AddmmBackward0>)


In [11]:

# Forward pass for Task B (Sentiment Analysis)
logits_task_b = mtl_model(inputs['input_ids'], inputs['attention_mask'], task='B')
print("Task B (Sentiment) Logits:\n", logits_task_b)


Task B (Sentiment) Logits:
 tensor([[ 0.2342,  0.1590, -0.2486],
        [ 0.1502,  0.1129,  0.1144]], grad_fn=<AddmmBackward0>)
