# 🧠 Multitask Transformer: News Summarization + Categorization
This notebook demonstrates how to build and train a multitask transformer model that performs both news **summarization** and **categorization** using a shared encoder architecture.

In [1]:
import torch
from torch import nn
from transformers import BartTokenizer, BartModel, BartConfig
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score
import random
from tqdm import tqdm

In [2]:
xsum = load_dataset('xsum', split='train[:1%]')
ag_news = load_dataset('ag_news', split='train[:1%]')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

In [3]:
len(ag_news), len(xsum)

(1200, 2040)

In [4]:
class MultiTaskDataset(Dataset):
    def __init__(self, xsum_data, ag_data, tokenizer, max_len=512):
        self.xsum_data = xsum_data
        self.ag_data = ag_data
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return max(len(self.xsum_data), len(self.ag_data))

    def __getitem__(self, idx):
        if idx % 2 == 0:  # Summarization task
            item = self.xsum_data[idx % len(self.xsum_data)]
            enc = self.tokenizer(item['document'], truncation=True, padding='max_length', max_length=self.max_len, return_tensors='pt')
            dec = self.tokenizer(item['summary'], truncation=True, padding='max_length', max_length=64, return_tensors='pt')
            return {
                'input_ids': enc['input_ids'].squeeze(0),
                'attention_mask': enc['attention_mask'].squeeze(0),
                'labels': dec['input_ids'].squeeze(0),
                'task': 'summarization'
            }
        else:  # Classification task
            item = self.ag_data[idx % len(self.ag_data)]
            enc = self.tokenizer(item['text'], truncation=True, padding='max_length', max_length=self.max_len, return_tensors='pt')
            return {
                'input_ids': enc['input_ids'].squeeze(0),
                'attention_mask': enc['attention_mask'].squeeze(0),
                'label_class': item['label'],
                'task': 'classification'
            }

In [5]:
def multitask_collate(batch):
    # Separate summarization and classification items
    summarization_batch = [item for item in batch if item['task'] == 'summarization']
    classification_batch = [item for item in batch if item['task'] == 'classification']

    collated = {}

    if summarization_batch:
        collated['summarization'] = {
            'input_ids': torch.stack([item['input_ids'] for item in summarization_batch]),
            'attention_mask': torch.stack([item['attention_mask'] for item in summarization_batch]),
            'labels': torch.stack([item['labels'] for item in summarization_batch])
        }

    if classification_batch:
        collated['classification'] = {
            'input_ids': torch.stack([item['input_ids'] for item in classification_batch]),
            'attention_mask': torch.stack([item['attention_mask'] for item in classification_batch]),
            'label_class': torch.tensor([item['label_class'] for item in classification_batch])
        }

    return collated

In [6]:
class MultiTaskBart(nn.Module):
    def __init__(self, config, num_labels):
        super().__init__()
        self.encoder = BartModel(config).get_encoder()
        self.decoder = BartModel(config).get_decoder()
        self.lm_head = nn.Linear(config.d_model, config.vocab_size)
        self.classifier = nn.Linear(config.d_model, num_labels)

    def forward(self, input_ids, attention_mask, decoder_input_ids=None, task='summarization', label_class=None):
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        encoder_hidden_states = encoder_outputs[0]

        if task == 'summarization':
            decoder_outputs = self.decoder(
                input_ids=decoder_input_ids,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=attention_mask
            )
            return self.lm_head(decoder_outputs[0])  # (batch, seq, vocab)
        else:
            cls_representation = encoder_hidden_states[:, 0, :]  # use first token
            return self.classifier(cls_representation)  # (batch, num_labels)
        
    def generate(self, input_ids, attention_mask, max_length=64, eos_token_id=2):
        """
        Greedy decoding for summarization
        """
        self.eval()
        batch_size = input_ids.size(0)
        decoder_input_ids = torch.full((batch_size, 1), tokenizer.bos_token_id, dtype=torch.long)

        for _ in range(max_length):
            decoder_outputs = self.decoder(
                input_ids=decoder_input_ids,
                encoder_hidden_states=self.encoder(input_ids=input_ids, attention_mask=attention_mask)[0],
                encoder_attention_mask=attention_mask
            )
            logits = self.lm_head(decoder_outputs[0])  # (batch, seq_len, vocab)
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)  # (batch, 1)

            decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)

            # Stop decoding when all sequences have generated an <eos>
            if torch.all(next_token.squeeze(-1) == eos_token_id):
                break

        return decoder_input_ids


In [7]:
config = BartConfig.from_pretrained('facebook/bart-base')
model = MultiTaskBart(config, num_labels=4)  # AG News has 4 classes
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
dataset = MultiTaskDataset(xsum, ag_news, tokenizer)
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=multitask_collate)

In [8]:
num_epochs = 10

In [9]:
model.train()
for epoch in range(num_epochs):
    print(f"\nEpoch: {epoch+1}")
    for batch in tqdm(loader, desc="Training", leave=False):
        optimizer.zero_grad()

        # Summarization
        if 'summarization' in batch:
            input_ids = batch['summarization']['input_ids']
            attention_mask = batch['summarization']['attention_mask']
            labels = batch['summarization']['labels']

            output = model(input_ids, attention_mask, decoder_input_ids=labels, task='summarization')
            loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
            loss = loss_fn(output.view(-1, output.size(-1)), labels.view(-1))
            loss.backward()
            optimizer.step()

        # Classification
        if 'classification' in batch:
            input_ids = batch['classification']['input_ids']
            attention_mask = batch['classification']['attention_mask']
            label_class = batch['classification']['label_class']

            output = model(input_ids, attention_mask, task='classification')
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(output, label_class)
            loss.backward()
            optimizer.step()


Epoch: 1


                                                             


Epoch: 2


                                                             


Epoch: 3


                                                             


Epoch: 4


                                                             


Epoch: 5


                                                             


Epoch: 6


                                                             


Epoch: 7


                                                             


Epoch: 8


                                                             


Epoch: 9


                                                             


Epoch: 10


                                                             

In [10]:
# saving the model
import os

SAVE_PATH = "./multitask_model_checkpoint"
os.makedirs(SAVE_PATH, exist_ok=True)

# Save model weights
torch.save(model.state_dict(), os.path.join(SAVE_PATH, "model.pt"))

# Save optimizer state (optional, for resuming training)
torch.save(optimizer.state_dict(), os.path.join(SAVE_PATH, "optimizer.pt"))

# Save config and tokenizer
config.save_pretrained(SAVE_PATH)
tokenizer.save_pretrained(SAVE_PATH)

print("Model, tokenizer, and optimizer saved!")


Model, tokenizer, and optimizer saved!


Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0}


load the model and retrain it

In [None]:
# # Load tokenizer and config
# from transformers import BartTokenizer, BartConfig

# # SAVE_PATH = "./multitask_model_checkpoint"

# tokenizer = BartTokenizer.from_pretrained(SAVE_PATH)
# config = BartConfig.from_pretrained(SAVE_PATH)

# # Re-initialize model and optimizer
# model = MultiTaskBart(config, num_labels=4)
# model.load_state_dict(torch.load(os.path.join(SAVE_PATH, "model.pt")))

# optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
# optimizer.load_state_dict(torch.load(os.path.join(SAVE_PATH, "optimizer.pt")))

# model.train()



In [11]:
from sklearn.metrics import accuracy_score

model.eval()
all_preds = []
all_labels = []

# Use a DataLoader for test data
test_data = load_dataset('ag_news', split='test[:1%]')
tokenized = [tokenizer(x['text'], padding='max_length', truncation=True, max_length=512, return_tensors='pt') for x in test_data]
labels = [x['label'] for x in test_data]

with torch.no_grad():
    for i, sample in enumerate(tokenized):
        input_ids = sample['input_ids']
        attention_mask = sample['attention_mask']
        output = model(input_ids, attention_mask, task='classification')
        pred = torch.argmax(output, dim=1).item()
        all_preds.append(pred)
        all_labels.append(labels[i])

acc = accuracy_score(all_labels, all_preds)
print(f"Classification Accuracy: {acc:.4f}")


Classification Accuracy: 0.3947


In [12]:
import evaluate
rouge = evaluate.load("rouge")

model.eval()
generated = []
references = []

for sample in test_data:
    inputs = tokenizer(sample['document'], return_tensors='pt', truncation=True, padding='max_length', max_length=512)
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']

    output_ids = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=64)
    summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    generated.append(summary)
    references.append(sample['summary'])

results = rouge.compute(predictions=generated, references=references, use_stemmer=True)
print("ROUGE-1:", results["rouge1"])
print("ROUGE-2:", results["rouge2"])
print("ROUGE-L:", results["rougeL"])


KeyError: 'document'

## Testing

In [None]:
import torch
from transformers import BartTokenizer, BartConfig
import os

# Load tokenizer and config
CHECKPOINT_PATH = "./multitask_model_checkpoint"
tokenizer = BartTokenizer.from_pretrained(CHECKPOINT_PATH)
config = BartConfig.from_pretrained(CHECKPOINT_PATH)

# Define your custom model class again
class MultiTaskBart(torch.nn.Module):
    def __init__(self, config, num_labels):
        super().__init__()
        from transformers import BartModel
        self.encoder = BartModel(config).get_encoder()
        self.decoder = BartModel(config).get_decoder()
        self.lm_head = torch.nn.Linear(config.d_model, config.vocab_size)
        self.classifier = torch.nn.Linear(config.d_model, num_labels)

    def forward(self, input_ids, attention_mask, decoder_input_ids=None, task='summarization', label_class=None):
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        encoder_hidden_states = encoder_outputs[0]

        if task == 'summarization':
            decoder_outputs = self.decoder(
                input_ids=decoder_input_ids,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=attention_mask
            )
            return self.lm_head(decoder_outputs[0])
        else:
            cls_representation = encoder_hidden_states[:, 0, :]
            return self.classifier(cls_representation)

    def generate(self, input_ids, attention_mask, max_length=64, eos_token_id=2):
        self.eval()
        batch_size = input_ids.size(0)
        decoder_input_ids = torch.full((batch_size, 1), tokenizer.bos_token_id, dtype=torch.long)

        for _ in range(max_length):
            encoder_hidden = self.encoder(input_ids=input_ids, attention_mask=attention_mask)[0]
            decoder_outputs = self.decoder(
                input_ids=decoder_input_ids,
                encoder_hidden_states=encoder_hidden,
                encoder_attention_mask=attention_mask
            )
            logits = self.lm_head(decoder_outputs[0])
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)

            if torch.all(next_token.squeeze(-1) == eos_token_id):
                break

        return decoder_input_ids

# Instantiate and load model weights
model = MultiTaskBart(config, num_labels=4)
model.load_state_dict(torch.load(os.path.join(CHECKPOINT_PATH, "model.pt")))
model.eval()
print("✅ Model loaded.")

# ---------------------
# Inference
# ---------------------

text = """Heavy monsoon rains have caused severe flooding in several parts of the country, displacing thousands of people and damaging homes and infrastructure."""

# Tokenize
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)

# ---- Summarization ----
with torch.no_grad():
    summary_ids = model.generate(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], max_length=64)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# ---- Classification ----
with torch.no_grad():
    logits = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"], task="classification")
    predicted_class = torch.argmax(logits, dim=1).item()

label_map = {
    0: "World",
    1: "Sports",
    2: "Business",
    3: "Sci/Tech"
}

print("Summary:", summary)
print("Predicted Category:", label_map[predicted_class])


In [None]:
# !pip install transformers datasets evaluate rouge_score absl-py scikit-learn pandas
# to train the model for custom categories
import torch
import pandas as pd
from transformers import BartTokenizer, BartConfig
from sklearn.metrics import accuracy_score
import evaluate
import os

# ----- PATHS -----
CHECKPOINT_PATH = "./multitask_model_checkpoint"
CUSTOM_DATA_PATH = "./custom_data.csv"  # CSV file with 'text' and 'label' columns

# ----- CUSTOM CATEGORIES -----
custom_labels = ['disaster', 'sports', 'finance', 'entertainment']  # modify as needed
label2id = {label: idx for idx, label in enumerate(custom_labels)}
id2label = {idx: label for label, idx in label2id.items()}
num_labels = len(label2id)

# ----- LOAD MODEL CONFIG + TOKENIZER -----
tokenizer = BartTokenizer.from_pretrained(CHECKPOINT_PATH)
config = BartConfig.from_pretrained(CHECKPOINT_PATH)

# ----- DEFINE MODEL -----
class MultiTaskBart(torch.nn.Module):
    def __init__(self, config, num_labels):
        super().__init__()
        from transformers import BartModel
        self.encoder = BartModel(config).get_encoder()
        self.decoder = BartModel(config).get_decoder()
        self.lm_head = torch.nn.Linear(config.d_model, config.vocab_size)
        self.classifier = torch.nn.Linear(config.d_model, num_labels)

    def forward(self, input_ids, attention_mask, decoder_input_ids=None, task='summarization'):
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        encoder_hidden_states = encoder_outputs[0]

        if task == 'summarization':
            decoder_outputs = self.decoder(
                input_ids=decoder_input_ids,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=attention_mask
            )
            return self.lm_head(decoder_outputs[0])
        else:
            cls_representation = encoder_hidden_states[:, 0, :]
            return self.classifier(cls_representation)

    def generate(self, input_ids, attention_mask, max_length=64, eos_token_id=2):
        self.eval()
        batch_size = input_ids.size(0)
        decoder_input_ids = torch.full((batch_size, 1), tokenizer.bos_token_id, dtype=torch.long)
        for _ in range(max_length):
            enc_hidden = self.encoder(input_ids=input_ids, attention_mask=attention_mask)[0]
            decoder_outputs = self.decoder(
                input_ids=decoder_input_ids,
                encoder_hidden_states=enc_hidden,
                encoder_attention_mask=attention_mask
            )
            logits = self.lm_head(decoder_outputs[0])
            next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
            if torch.all(next_token.squeeze(-1) == eos_token_id):
                break
        return decoder_input_ids

# ----- LOAD MODEL -----
model = MultiTaskBart(config, num_labels=num_labels)
model.load_state_dict(torch.load(os.path.join(CHECKPOINT_PATH, "model.pt")))
model.eval()

# ----- LOAD CUSTOM DATA -----
df = pd.read_csv(CUSTOM_DATA_PATH)
df = df[df['label'].isin(custom_labels)]  # filter to allowed classes
df['label_id'] = df['label'].map(label2id)

# ----- CLASSIFICATION -----
print("\n📚 Classification on custom dataset...")
all_preds, all_labels = [], []

for _, row in df.iterrows():
    text = row['text']
    label_id = row['label_id']

    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=512)

    with torch.no_grad():
        logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], task="classification")
        pred = torch.argmax(logits, dim=1).item()

    all_preds.append(pred)
    all_labels.append(label_id)

acc = accuracy_score(all_labels, all_preds)
print("✅ Classification Accuracy:", acc)
print("🧾 Predicted Labels:", [id2label[p] for p in all_preds[:5]])

# ----- SUMMARIZATION (OPTIONAL for same input) -----
print("\n📝 Sample summarization (first 3 rows)...")
for _, row in df.head(3).iterrows():
    text = row['text']
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=512)

    with torch.no_grad():
        output_ids = model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
        summary = tokenizer.decode(output_ids[0], skip_special_tokens=True)

    print(f"\n📄 Text: {text[:100]}...")
    print(f"📝 Summary: {summary}")
