# 02 Multi-Task Transformer Demo

This notebook demonstrates:
1. Building a multi-task model on top of a shared transformer backbone.
2. Sentence-level classification (Task A).
3. Token-level classification (Task B), e.g. named-entity recognition (NER).


In [None]:
import torch
from transformers import AutoTokenizer
from src.multitask_model import MultiTaskSentenceTransformer

# Initialize the multi-task model: one shared backbone feeding a
# sentence-level head (Task A) and a token-level head (Task B).
model_name = 'distilbert-base-uncased'
model = MultiTaskSentenceTransformer(
    model_name=model_name,
    num_classes_task_a=3,  # e.g., 3 possible sentence classes
    num_labels_task_b=5,   # e.g., 5 labels for NER
    pooling='mean'
)
# Inference mode (disables dropout) so the printed logits are reproducible.
model.eval()

tokenizer = AutoTokenizer.from_pretrained(model_name)

sentences = [
    "Barack Obama was the 44th President of the United States.",
    "I love exploring advanced transformer architectures."
]

# padding=True pads every sentence to the longest one in the batch;
# attention_mask marks real tokens with 1 and padding with 0.
encoded = tokenizer(
    sentences,
    padding=True,
    truncation=True,
    max_length=32,
    return_tensors='pt'
)

with torch.no_grad():
    outputs = model(encoded['input_ids'], encoded['attention_mask'])

task_a_logits = outputs['task_a_logits']
task_b_logits = outputs['task_b_logits']

print('Task A logits:', task_a_logits)
print('Task B logits shape:', task_b_logits.shape)

# Inspect token-level (Task B) predictions.
# NOTE(review): assumes task_b_logits is (batch, seq_len, num_labels_task_b),
# aligned position-by-position with input_ids. Confirm against the model.
# If the heads are freshly initialized (presumably the case here), the labels
# are arbitrary; this only illustrates the output layout.
task_b_preds = task_b_logits.argmax(dim=-1)
for i, (ids, mask) in enumerate(zip(encoded['input_ids'], encoded['attention_mask'])):
    # Keep only real tokens. Padding positions are not part of the sentence,
    # and a boolean mask works whichever side the tokenizer pads on.
    keep = mask.bool()
    tokens = tokenizer.convert_ids_to_tokens(ids[keep].tolist())
    token_labels = task_b_preds[i][keep].tolist()

    print(f"\nSentence {i+1} tokens:", tokens)
    print("Task B logits shape (padding removed):", tuple(task_b_logits[i][keep].shape))
    print("Task B predicted label per token:")
    for token, label in zip(tokens, token_labels):
        print(f"  {token:>15} -> {label}")
