In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# TASK 1 #

In [27]:

class SentenceTransformer(torch.nn.Module):
    def __init__(self, model_name="bert-base-uncased"):
        super(SentenceTransformer, self).__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(self, sentences):
        # automatic padding and truncation when processing inputs of different sizes
        # returns as pytorch tensors
        inputs = self.tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
        outputs = self.model(**inputs)

        # attention mask to ignore the padding
        attention_mask = inputs['attention_mask'].unsqueeze(-1)
        token_embeddings = outputs.last_hidden_state

        # mean pooling gives avg of vector embeddings, good for looking at similarities, captures context of entire chunk
        summed = torch.sum(token_embeddings * attention_mask, dim=1)
        counts = torch.clamp(attention_mask.sum(dim=1), min=1e-9)
        mean_pooled = summed / counts

        # L2 normalization for our loss function later
        embeddings = F.normalize(mean_pooled, p=2, dim=1)
        return embeddings


In [24]:
model = SentenceTransformer()
model.eval()

sentences = [
    "My favorite show right now is Twin Peaks.",
    "David Lynch was the one who wrote Blue Velvet",
    "Twin Peaks was written by David Lynch and Mark Frost.",
    "Every finite language is regular, but not every regular language is finite.",
    "There are many applications for Context Free Grammars",
]

with torch.no_grad():
    embeddings = model(sentences)

print("Embeddings shape:", embeddings.shape[1])
print(embeddings)


Embeddings shape: 768
tensor([[ 0.0135, -0.0153, -0.0146,  ...,  0.0093,  0.0225, -0.0329],
        [ 0.0243, -0.0455, -0.0531,  ..., -0.0036,  0.0349, -0.0083],
        [ 0.0696, -0.0363, -0.0494,  ...,  0.0026,  0.0141, -0.0053],
        [-0.0223, -0.0005, -0.0050,  ..., -0.0093, -0.0485,  0.0680],
        [-0.0006,  0.0084, -0.0038,  ..., -0.0211, -0.0363,  0.0531]])




# TASK 2 #


The tasks I want to implement here is sentence classification and Named Entity Recognition

In [28]:
class MultiTaskModel(nn.Module):
    def __init__(self, model_name, num_classes, num_ner_labels):
        super(MultiTaskModel, self).__init__()
        self.model = AutoModel.from_pretrained(model_name)
        hidden_size = self.model.config.hidden_size

        
        # create two classifier heads to handle multitask
        self.classifier = nn.Linear(hidden_size, num_classes)          # For sentence classification
        self.ner_classifier = nn.Linear(hidden_size, num_ner_labels)    # For token-level NER

    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)

        # instead of just mean pooling to get fixed size emebeddings, get two different outputs
        # for the two tasks
        
        # CLS is better for sentence classification, the full token sequence is good for NER
        
        sequence_output = outputs.last_hidden_state         # (batch_size, seq_len, hidden_dim)
        cls_output = sequence_output[:, 0, :]               # cls token at position 0
        
        class_logits = self.classifier(cls_output)          # Sentence classification output
        ner_logits = self.ner_classifier(sequence_output)   # NER output (token-wise)
        
        return class_logits, ner_logits

    def compute_loss(self, class_logits, ner_logits, class_labels, ner_labels, 
                 classification_weight=1.0, ner_weight=1.0, ner_ignore_index=-100):
        
        # Classification loss
        classification_loss = F.cross_entropy(class_logits, class_labels)
        
        # NER loss
        ner_logits_flat = ner_logits.view(-1, ner_logits.size(-1))
        ner_labels_flat = ner_labels.view(-1)
        ner_loss = F.cross_entropy(ner_logits_flat, ner_labels_flat, ignore_index=ner_ignore_index)
    
        # weighted joint loss, default is evenly weighted but option to change depending on finetuning needs or dataset specifics
        joint_loss = classification_weight * classification_loss + ner_weight * ner_loss
        
        return joint_loss, classification_loss, ner_loss


In [30]:

sentences = [
    "My favorite show right now is Twin Peaks.",
    "David Lynch was the one who wrote Blue Velvet",
    "Twin Peaks was written by David Lynch and Mark Frost.",
    "Every finite language is regular, but not every regular language is finite.",
    "There are many applications for Context Free Grammars",
]

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = MultiTaskModel("bert-base-uncased", num_classes=5, num_ner_labels=9)

inputs = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
class_logits, ner_logits = model(inputs['input_ids'], inputs['attention_mask'])

print(class_logits)  # (batch_size, num_classes)
print(ner_logits)    # (batch_size, seq_len, num_ner_labels)




tensor([[ 0.1679, -0.5082,  0.3046, -0.4215, -0.2039],
        [ 0.1733, -0.2227, -0.0356, -0.3514, -0.2043],
        [ 0.2149, -0.2663,  0.2406, -0.4135, -0.3215],
        [-0.0585, -0.2945,  0.4251, -0.3147, -0.3437],
        [ 0.1383, -0.1643,  0.4067, -0.2555, -0.0611]],
       grad_fn=<AddmmBackward0>)
tensor([[[-8.6707e-01,  2.3506e-01,  1.0435e-01,  2.0533e-01,  2.6958e-02,
          -1.4714e-01,  1.8400e-01, -5.3831e-01, -1.4410e-01],
         [-3.2364e-01,  8.0927e-02, -2.2835e-02,  5.0003e-02,  2.0383e-01,
          -2.9610e-01,  6.3430e-02, -2.2567e-02, -9.7748e-02],
         [ 2.3590e-01,  4.7415e-01, -1.0983e-01, -2.9346e-02,  2.6646e-01,
           3.7116e-02,  1.1024e-01, -5.9155e-02,  6.4788e-01],
         [-3.6942e-01,  2.9288e-01, -3.5211e-01, -5.6537e-01,  6.8895e-02,
           3.5708e-01,  2.6148e-03,  2.6933e-01, -1.7191e-03],
         [ 4.3264e-01,  1.1943e-01, -2.1869e-01,  1.1204e-01, -7.1670e-01,
           1.6291e-01, -2.1125e-02, -1.3089e-01, -2.6327e-01],
 

# TASK 3 #

# Discuss the implications and advantages of each scenario and make sure to explain your rationale as to how the model should be trained given the following: #

### If the entire network was frozen, ###

then we are relying entirely on the weights that are given on the pretrained model. The advantage would be that training is extremely quick and very low in computational costs, but the disadvantage would be that it can't be trained to do anything new. The rationale behind this is if the task we want to do is in a very similar domain to the current network, so it would already perform well enough for our needs.

## If we only froze the transformer backbone, ##

then we would only be training the task specific heads by leaving them unfrozen during training. The training would be relatively quick and we are able to train the task specific heads to our liking. The rationale behind this would be if our task is related but not identical to the data the transformer backbone was trained on, and also if we don't have very large amounts of data to be worth training the entire network for.

## If we froze only one task-specific head, ##

then only one head would continue learning depending on what we chose to not freeze. The advantage and rationale behind this is if one task is already performing well enough, then we can protect/preserve its performance while improving the performance of the other task specific head.

#  Consider a scenario where transfer learning can be beneficial. Explain how you would approach the transfer learning process

## Choice of pretrained model ##

The choice of the pretrained model would depend on what domain the task we are trying to complete is in. Lets say we want to train a task-specific head that can do Named Entity Recognition on the names of medicines. We should choose something that is related to that domain, so we can leverage off of a pretrained model like Med-BERT that is trained on medical texts already

## Layers to freeze/unfreeze ## 

Unless we have a large enough amount of data to train on to justify unfreezing layers of the pretrained model, we would want to keep the backbone frozen while we unfreeze the task-specific head in order for it to learn. One of the main reasons why transfer learning is beneficial is that it is a solution when it comes to the problem of getting large enough datasets. By utilizing the fact that the pretrained models are trained on much larger datasets beforehand, we can speed up our training and rely on a smaller dataset without worrying too much about overfitting.





# TASK 4

#### Here would be a hypothetical training loop for our model, assuming we have a real dataset on hand and we created a dataset class to be loaded using the 

#### DataLoader. This is probably just a bit more than a bare skeleton for a training loop, as we can also add stuff like scheduling for the learning rate or even 

#### utilize something like Weights and Biases/WandB to make searching and finetuning the hyperparameters very easy and convenient.

In [None]:
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, classification_report



def train_mtl_model(model, train_dataloader, val_dataloader, optimizer, device, 
                    num_epochs, classification_weight=1.0, ner_weight=1.0):
    """
    Args:
        model: MultiTaskModel instance
        train_dataloader: DataLoader for training data
        val_dataloader: DataLoader for validation data
        optimizer: torch.optim optimizer
        device: 'cuda' or 'cpu'
        num_epochs: number of training epochs
        classification_weight: weight for classification loss
        ner_weight: weight for NER loss
    """
    
    # Move model to device
    model = model.to(device) # cuda if available if not then CPU
    
    best_val_loss = float('inf')
    trigger_times = 0 
    patience = 5    # early stopping
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        
        total_train_loss = 0
        train_class_correct = 0
        train_class_total = 0
        train_ner_correct = 0
        train_ner_total = 0
        
        # For accumulating true/pred labels across batches for macro metrics, important because we have multiple tasks, so we can see if 
        #  a certain task is performing better/worse than the other
        all_class_true = []
        all_class_pred = []
        all_ner_true = []
        all_ner_pred = []

        #         tqdm just for visual clarity although im not even actually training it
        for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1} Training"):
            
            # Unpack batch data assuming the data is in this format
            # batch = {input_ids, attention_mask, class_labels, ner_labels}
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            class_labels = batch['class_labels'].to(device)
            ner_labels = batch['ner_labels'].to(device)

            
            optimizer.zero_grad()  # reset gradients from previous batch
            
            #Forward pass returns two outputs since it computes both classification and NER outputs
            class_logits, ner_logits = model(input_ids, attention_mask)
            
            # Compute joint multi-task loss
            loss, class_loss, ner_loss = model.compute_loss(
                class_logits, ner_logits, class_labels, ner_labels,
                classification_weight, ner_weight
            )
            
            # Backward pass, compute gradients
            loss.backward()  # Backpropagate the joint loss
            
            # Optimizer step, update model parameters
            optimizer.step()  

            # Accumulate loss
            total_train_loss += loss.item()
            
            # Calculate training accuracy for both tasks
            # Classification task
            _, class_preds = torch.max(class_logits, 1)
            train_class_correct += (class_preds == class_labels).sum().item()
            train_class_total += class_labels.size(0)
            
            # NER task (ignore padded tokens with label -100)
            ner_preds = torch.argmax(ner_logits, dim=-1)
            active_tokens = (ner_labels != -100)  # mask for non-padded tokens
            train_ner_correct += ((ner_preds[active_tokens] == ner_labels[active_tokens]).sum().item())
            train_ner_total += active_tokens.sum().item()
            
            # Store for macro metrics
            all_class_true.extend(class_labels.cpu().numpy())
            all_class_pred.extend(class_preds.cpu().numpy())
            
            # For NER, only consider non-padded tokens
            masked_ner_true = ner_labels[active_tokens].cpu().numpy()
            masked_ner_pred = ner_preds[active_tokens].cpu().numpy()
            all_ner_true.extend(masked_ner_true)
            all_ner_pred.extend(masked_ner_pred)

        ## Here we can begin to do the validation part
        val_metrics = {
            'loss': 0.0,
            'class_correct': 0,
            'class_total': 0,
            'ner_correct': 0,
            'ner_total': 0,
            'class_true': [],
            'class_pred': [],
            'ner_true': [],
            'ner_pred': []
        }
        
        with torch.no_grad():  # validation part of the loop
            for batch in tqdm(val_dataloader, desc="Validating"):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                class_labels = batch['class_labels'].to(device)
                ner_labels = batch['ner_labels'].to(device)
                
                # forward pass no backprop
                class_logits, ner_logits = model(input_ids, attention_mask)
                
                # Loss calculation
                loss, _, _ = model.compute_loss(
                    class_logits, ner_logits, class_labels, ner_labels,
                    classification_weight, ner_weight
                )
                val_metrics['loss'] += loss.item()
                
                # Classification metrics
                _, class_preds = torch.max(class_logits, 1)
                val_metrics['class_correct'] += (class_preds == class_labels).sum().item()
                val_metrics['class_total'] += class_labels.size(0)
                
                # NER metrics
                ner_preds = torch.argmax(ner_logits, dim=-1)
                active_tokens = (ner_labels != -100)
                val_metrics['ner_correct'] += ((ner_preds[active_tokens] == ner_labels[active_tokens]).sum().item())
                val_metrics['ner_total'] += active_tokens.sum().item()
                
                # Store for macro metrics
                val_metrics['class_true'].extend(class_labels.cpu().numpy())
                val_metrics['class_pred'].extend(class_preds.cpu().numpy())
                val_metrics['ner_true'].extend(ner_labels[active_tokens].cpu().numpy())
                val_metrics['ner_pred'].extend(ner_preds[active_tokens].cpu().numpy())
        
        # Calculate validation metrics
        avg_val_loss = val_metrics['loss'] / len(val_dataloader)
        val_class_acc = val_metrics['class_correct'] / val_metrics['class_total']
        val_ner_acc = val_metrics['ner_correct'] / val_metrics['ner_total']
        val_class_f1 = f1_score(val_metrics['class_true'], val_metrics['class_pred'], average='macro')
        val_ner_f1 = f1_score(val_metrics['ner_true'], val_metrics['ner_pred'], average='macro')
        
        # Print epoch stats
        print(f"\nEpoch {epoch + 1} Validation Results:")
        print(f"Val Loss: {avg_val_loss:.4f} | Class Acc: {val_class_acc:.4f} | NER Acc: {val_ner_acc:.4f}")
        print(f"Class F1: {val_class_f1:.4f} | NER F1: {val_ner_f1:.4f}")

        
        # Save best model and/or
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            trigger_times = 0
            torch.save(model.state_dict(), 'best_mtl_model.pt')
            print("New best model saved!")
        else:
            trigger_times += 1
            if trigger_times > patience:
                break