In [1]:
import torch
from torch import nn
from transformers import DistilBertModel, DistilBertTokenizer
from torch.optim import AdamW

In [2]:
# Multi-Task Model
class MultiTaskSentenceTransformer(nn.Module):
    """DistilBERT encoder with two linear classification heads on one shared sentence embedding.

    Sentences are tokenized inside ``forward``, encoded by the backbone, and
    reduced to one vector per sentence by attention-mask-weighted mean pooling.
    That vector is fed to both heads.

    Args:
        model_name: Checkpoint name passed to ``from_pretrained`` for both the
            tokenizer and the backbone.
        num_classes_task_a: Output size of the Task A head.
        num_classes_task_b: Output size of the Task B head.
    """

    def __init__(self, model_name="distilbert-base-uncased", num_classes_task_a=3, num_classes_task_b=4):
        super().__init__()
        # Load the pre-trained transformer model and tokenizer
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        self.model = DistilBertModel.from_pretrained(model_name)

        hidden_size = self.model.config.hidden_size

        # Task A: Sentence Classification (Positive, Negative, Neutral)
        self.task_a_classifier = nn.Linear(hidden_size, num_classes_task_a)

        # Task B: Sentiment Analysis (Happy, Sad, Angry, Neutral)
        self.task_b_classifier = nn.Linear(hidden_size, num_classes_task_b)

    def forward(self, sentences):
        """Encode ``sentences`` and return the logits of both heads.

        Args:
            sentences: Text input handed straight to the tokenizer.

        Returns:
            ``(task_a_output, task_b_output)``: the outputs of the Task A and
            Task B linear heads, one row per input sentence.
        """
        inputs = self.tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")

        # The tokenizer always produces CPU tensors; move them to wherever the
        # backbone lives so the module keeps working after `.to("cuda")`.
        device = next(self.model.parameters()).device
        inputs = {
            key: value.to(device) if torch.is_tensor(value) else value
            for key, value in inputs.items()
        }

        # Forward pass through the transformer model
        outputs = self.model(**inputs)
        token_embeddings = outputs.last_hidden_state

        # Mean pooling over real (non-padding) tokens only. The mask is cast to
        # the embedding dtype so the arithmetic does not rely on implicit
        # int -> float promotion.
        mask = inputs["attention_mask"].unsqueeze(-1).to(token_embeddings.dtype)
        summed = (token_embeddings * mask).sum(dim=1)
        # Guard the divisor so an all-zero mask row cannot produce NaN/inf.
        token_counts = mask.sum(dim=1).clamp(min=1.0)
        sentence_embeddings = summed / token_counts

        # Both heads consume the same pooled embedding.
        task_a_output = self.task_a_classifier(sentence_embeddings)
        task_b_output = self.task_b_classifier(sentence_embeddings)

        return task_a_output, task_b_output

In [3]:
#Layer wise learning rate implementation
def layerwise_optimizer(model, lower_lr=1e-6, middle_lr=1e-5, higher_lr=1e-4):
    """Build an AdamW optimizer with layer-wise learning rates.

    Parameter groups, in order:
        0. backbone embeddings                       -> ``lower_lr``
        1. lower half of the transformer blocks      -> ``middle_lr``
        2. upper half of the transformer blocks      -> ``higher_lr``
        3. Task A head                               -> ``higher_lr``
        4. Task B head                               -> ``higher_lr``

    Args:
        model: Module exposing ``model.embeddings``, ``model.transformer.layer``,
            ``task_a_classifier`` and ``task_b_classifier``.
        lower_lr: Learning rate for the embeddings.
        middle_lr: Learning rate for the lower transformer blocks.
        higher_lr: Learning rate for the upper transformer blocks and both heads.

    Returns:
        A ``torch.optim.AdamW`` instance over the five groups above.
    """
    backbone = model.model
    blocks = backbone.transformer.layer

    # Split at the midpoint instead of a hard-coded index 6. DistilBERT has only
    # six transformer blocks, so `layer[:6]` covered all of them and `layer[6:]`
    # was empty: no block was ever trained with higher_lr. For a 12-block
    # encoder the midpoint is still 6, so that case is unchanged.
    split = len(blocks) // 2

    # Materialise each parameter iterator into a list so the groups can be
    # inspected or reused without silently exhausting a generator.
    param_groups = [
        {"params": list(backbone.embeddings.parameters()), "lr": lower_lr},
        {"params": list(blocks[:split].parameters()), "lr": middle_lr},
        {"params": list(blocks[split:].parameters()), "lr": higher_lr},
        {"params": list(model.task_a_classifier.parameters()), "lr": higher_lr},
        {"params": list(model.task_b_classifier.parameters()), "lr": higher_lr},
    ]

    return AdamW(param_groups)

In [4]:
# Instantiate the two-headed model, then build its AdamW optimizer with
# per-group learning rates.
model = MultiTaskSentenceTransformer(
    model_name="distilbert-base-uncased",
    num_classes_task_a=3,
    num_classes_task_b=4,
)
optimizer = layerwise_optimizer(
    model,
    lower_lr=1e-6,
    middle_lr=1e-5,
    higher_lr=1e-4,
)

Layer-wise learning rates can reduce fine-tuning cost while retaining accuracy. Because the backbone is a pretrained model (`distilbert-base-uncased`), the early layers can use a small learning rate: they capture general language features that the pretrained model has already learned. The middle layers capture more task-relevant features, so they use a higher learning rate. The final layers are the task-specific heads, which are newly initialized and therefore use the highest learning rate. Assigning learning rates this way avoids spending training effort relearning information the backbone already encodes; the learning rate is raised only in the layers that adapt the model to the specific tasks at hand.