## Fine-tune a BERT model to identify specific topics

In [None]:
import torch, torch.nn as nn
from transformers import (
    BertModel, BertTokenizer, BertPreTrainedModel,
    DataCollatorWithPadding, TrainingArguments, Trainer
)
from sklearn.preprocessing import LabelEncoder
from datasets import Dataset
import pandas as pd, numpy as np
from sklearn.metrics import accuracy_score, f1_score


def compute_metrics(pred):
    """Score the three task heads on an evaluation set.

    Args:
        pred: ``EvalPrediction``-like object. ``pred.predictions`` must unpack
            into ``(impact_logits, urgency_logits, resource_logits)``.
            ``pred.label_ids`` may be either

            * a dict keyed by label column (``"impact_type_id"``,
              ``"urgency_id"`` and every name in ``resource_cols``), or
            * a tuple/list ordered as impact labels, urgency labels, then one
              array per resource column (assumed to match the order in which
              the label columns are handed to the trainer).

    Returns:
        dict with accuracy for the impact and urgency heads and micro/macro F1
        for the multi-label resource head.

    Raises:
        ValueError: if ``pred.label_ids`` is ``None``.
    """
    impact_logits, urgency_logits, resource_logits = pred.predictions
    labels = pred.label_ids
    if labels is None:
        raise ValueError(
            "compute_metrics received no labels; the label columns must be "
            "collected during evaluation"
        )

    if isinstance(labels, dict):
        impact_labels = labels["impact_type_id"]
        urgency_labels = labels["urgency_id"]
        resource_columns = [labels[col] for col in resource_cols]
    else:
        # Tuple layout: the first two entries are the single-label tasks,
        # everything after them is one binary column per resource.
        impact_labels, urgency_labels, *resource_columns = labels

    # np.asarray accepts NumPy arrays and CPU tensors alike, whereas
    # torch.stack rejects the NumPy arrays an evaluation loop hands over.
    impact_labels = np.asarray(impact_labels)
    urgency_labels = np.asarray(urgency_labels)
    resource_labels = np.stack(
        [np.asarray(column) for column in resource_columns], axis=1
    ).astype(int)

    impact_preds = np.asarray(impact_logits).argmax(axis=1)
    urgency_preds = np.asarray(urgency_logits).argmax(axis=1)
    # Multi-label head: each resource is predicted independently at p > 0.5.
    resource_probs = torch.sigmoid(torch.as_tensor(np.asarray(resource_logits)))
    resource_preds = (resource_probs > 0.5).int().numpy()

    return {
        "impact_acc": accuracy_score(impact_labels, impact_preds),
        "urgency_acc": accuracy_score(urgency_labels, urgency_preds),
        "resource_f1_micro": f1_score(resource_labels, resource_preds, average="micro"),
        "resource_f1_macro": f1_score(resource_labels, resource_preds, average="macro"),
    }
# ──────────────────────────────────────────────────────────────
# 0.  DEVICE
# ──────────────────────────────────────────────────────────────

# Prefer the Apple-silicon GPU (MPS), then CUDA, and finally fall back to the
# CPU so the script still runs on machines without an accelerator.
if torch.backends.mps.is_available():
    device = torch.device("mps")  # Apple-silicon GPU
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# ──────────────────────────────────────────────────────────────
# 1.  MULTI‑TASK MODEL
# ──────────────────────────────────────────────────────────────

class MultiTaskBERT(BertPreTrainedModel):
    """Shared BERT encoder feeding three linear heads.

    Each head (impact, urgency, resource) is a single ``nn.Linear`` applied to
    the encoder's pooled output after dropout.
    """

    def __init__(self, config, num_impact: int, num_urgency: int, num_resource: int):
        """Build the encoder and one linear head per task.

        Args:
            config: BERT configuration; ``config.hidden_size`` sets the input
                width of every head.
            num_impact: output width of the impact head.
            num_urgency: output width of the urgency head.
            num_resource: output width of the resource head.
        """
        super().__init__(config)
        hidden = config.hidden_size
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(0.1)
        self.impact_head = nn.Linear(hidden, num_impact)
        self.urgency_head = nn.Linear(hidden, num_urgency)
        self.resource_head = nn.Linear(hidden, num_resource)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        """Return raw logits for every head.

        Extra keyword arguments (for example label columns left in the batch)
        are accepted and ignored.

        Returns:
            dict with keys ``"impact"``, ``"urgency"`` and ``"resource"``, in
            that order, each holding the corresponding head's output.
        """
        encoded = self.bert(input_ids, attention_mask)
        features = self.dropout(encoded.pooler_output)
        heads = (
            ("impact", self.impact_head),
            ("urgency", self.urgency_head),
            ("resource", self.resource_head),
        )
        return {task: head(features) for task, head in heads}

# ──────────────────────────────────────────────────────────────
# 2.  LOAD DATA & ENCODE LABELS
# ──────────────────────────────────────────────────────────────

# NOTE: unpickling executes arbitrary code, so only load files you produced yourself.
df_train = pd.read_pickle("./for_bert/expanded_train.pkl")
df_val = pd.read_pickle("./for_bert/expanded_val.pkl")

# Fit each encoder on the training split only, then reuse the fitted mapping
# for validation so both splits share the same integer ids.
impact_enc = LabelEncoder().fit(df_train["impact_type"])
urgency_enc = LabelEncoder().fit(df_train["urgency"])

df_train["impact_type_id"] = impact_enc.transform(df_train["impact_type"])
df_val["impact_type_id"] = impact_enc.transform(df_val["impact_type"])

df_train["urgency_id"] = urgency_enc.transform(df_train["urgency"])
df_val["urgency_id"] = urgency_enc.transform(df_val["urgency"])

# Every column named "resource_*" is treated as one target of the resource head.
resource_cols = [name for name in df_train.columns if name.startswith("resource_")]

# Columns handed to the model: token ids plus all label columns.
cols = ["input_ids", "impact_type_id", "urgency_id", *resource_cols]

# build HF datasets, exposing only `cols` as torch tensors
train_ds = Dataset.from_pandas(df_train[cols]).with_format(type="torch", columns=cols)
val_ds = Dataset.from_pandas(df_val[cols]).with_format(type="torch", columns=cols)

print(train_ds.column_names)

# ──────────────────────────────────────────────────────────────
# 3.  TOKENIZER & COLLATOR (dynamic padding)
# ──────────────────────────────────────────────────────────────

# The tokenizer is only needed here for its padding settings; the collator
# pads each batch, rounding the padded length up to a multiple of 8.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    pad_to_multiple_of=8,
    return_tensors="pt",
)

# ──────────────────────────────────────────────────────────────
# 4.  BUILD MODEL (full fine-tuning; uncomment the loop below to freeze the backbone)
# ──────────────────────────────────────────────────────────────

# Load the pretrained checkpoint; head widths come from the fitted label
# encoders and the number of resource columns.
model = MultiTaskBERT.from_pretrained(
    "bert-base-uncased",
    num_impact=len(impact_enc.classes_),
    num_urgency=len(urgency_enc.classes_),
    num_resource=len(resource_cols),
)
model = model.to(device)

#for p in model.bert.parameters():  # comment out to fine‑tune full model
#    p.requires_grad = False

# ──────────────────────────────────────────────────────────────
# 5.  CUSTOM TRAINER
# ──────────────────────────────────────────────────────────────

class MultiTaskTrainer(Trainer):
    """Trainer whose loss is a weighted sum of the three task-head losses."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute the combined multi-task loss for one batch.

        Args:
            model: model returning a dict with ``"impact"``, ``"urgency"`` and
                ``"resource"`` logits.
            inputs: batch dict holding ``input_ids``, ``attention_mask``,
                ``impact_type_id``, ``urgency_id`` and one entry per name in
                ``resource_cols``.
            return_outputs: when true, also return the model outputs.
            **kwargs: extra arguments passed by the caller; ignored.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs`` is
            true, where ``outputs`` is the model's dict of logit tensors.
        """
        # Move tensors to the correct device. Building a new dict also means
        # the pops below never strip label columns from the caller's batch.
        inputs = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}

        y_impact = inputs.pop("impact_type_id")
        y_urgency = inputs.pop("urgency_id")
        # One binary column per resource -> (batch, num_resource) float targets
        y_resource = torch.stack([inputs.pop(col) for col in resource_cols], dim=1).float()

        outputs = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

        # Single-label heads use cross-entropy; the multi-label resource head
        # uses per-column binary cross-entropy on the raw logits.
        loss_impact = nn.functional.cross_entropy(outputs["impact"], y_impact)
        loss_urgency = nn.functional.cross_entropy(outputs["urgency"], y_urgency)
        loss_resource = nn.functional.binary_cross_entropy_with_logits(outputs["resource"], y_resource)

        # You can adjust these weights based on validation later
        total_loss = 1.0 * loss_impact + 1.0 * loss_urgency + 0.5 * loss_resource

        # Return the dict of logit tensors as-is. Trainer.prediction_step slices
        # non-dict outputs with outputs[1:] (it expects the loss first), which
        # would silently drop the impact logits from a bare tuple, and it
        # expects tensors rather than NumPy arrays.
        return (total_loss, outputs) if return_outputs else total_loss


# ──────────────────────────────────────────────────────────────
# 6.  TRAINING ARGS
# ──────────────────────────────────────────────────────────────

# NOTE(review): `label_names` is left at its default. Trainer appears to gather
# labels (and therefore call `compute_metrics`) only for columns listed there,
# so the metrics callback may never run with this configuration - confirm
# before relying on the reported evaluation metrics.
args = TrainingArguments(
    # where checkpoints and logs go
    output_dir="./bert_multitask_model",
    logging_dir="./logs",
    report_to="none",
    # schedule
    num_train_epochs=5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # evaluate and checkpoint once per epoch
    eval_strategy="epoch",
    save_strategy="epoch",
    # keep label columns in the batch; the custom loss reads them from `inputs`
    remove_unused_columns=False,
)

# ──────────────────────────────────────────────────────────────
# 7.  TRAIN
# ──────────────────────────────────────────────────────────────

# Wire the model, data, batching and metrics into the custom trainer, then run
# the full training loop.
trainer = MultiTaskTrainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    data_collator=collator,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
)

trainer.train()

# ──────────────────────────────────────────────────────────────
# 8.  SAVE MODEL
# ──────────────────────────────────────────────────────────────

# Write the trained model and the tokenizer files into the same directory.
final_dir = "./for_bert/bert_multitask_model/final"
trainer.save_model(final_dir)
tokenizer.save_pretrained(final_dir)

In [None]:
pip install scikit-learn
