In [1]:
import numpy as np
import torch
import math
from torch import nn
import torch.nn.functional as F
from tokenizer import Tokenizer

import regex as re
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from enums import CLS_ID, PAD_ID, IGNORE_INDEX, MAX_LEN, DROPOUT

# Sentence Transformer Multi-Task Expansion Overview

## Dataset
* MASSIVE dataset (English split) for multi-task learning, which contains User utterances paired with:

> Intent labels (classification task)

> Slot annotations for Named Entity Recognition (NER) (sequence tagging task)


* Key steps:

> BPE tokenization is performed using a custom tokenizer.

> Slot labels are aligned with BPE pieces, expanding one label across subword units.

> Both input IDs and slot labels are padded/truncated to a maximum length (MAX_LEN).
 

* Oversampling is used during training to give more weight to examples containing at least one NER tag (to handle class imbalance).



## Multi-Task Model Architecture

* The multitask model builds on top of the SentenceTransformer backbone:

* Intent Classifier

> A simple Linear layer applied on the [CLS] token embedding for intent classification.

* NER Tagger

> A 2-layer bidirectional LSTM applied on the output sequence embeddings.

> A 2-layer MLP (LayerNorm → Linear → GELU → Dropout → Linear) classifies each token position.

## Design choice
Initially, only the new heads were trained. After multiple unfreezing backbone experiments, unfreezing the full backbone gave better performance.

## Loss Function

* Intent Loss

> Standard CrossEntropyLoss on the intent classification output.

* NER Loss

> CrossEntropyLoss applied token-wise, with class weighting to emphasize rare NER labels.

> The loss ignores padding and CLS tokens (using IGNORE_INDEX).

#### The final loss is the sum of intent loss and NER loss.

## Training phase details:

> Weighted Random Sampler to balance NER-positive examples.

> Mixed Precision Training (AMP) for faster and memory-efficient training.

> Gradient clipping for stability.

> Cosine Annealing LR scheduler.

> Early stopping based on validation loss.

## Inference
* During inference:

> Input sentences are tokenized and passed through the multitask model.

> The model predicts: Intent label (single prediction). And NER tags for each token position (sequence prediction).

* A simple sanity check shows:

> Intent classification performs reasonably well early on.

> NER tagging performance could improve with larger models, more data, or task-specific pretraining.



In [2]:
class MultiTaskDataset(Dataset):
    def __init__(self, tokenizer_path, split):
        self.tokenizer = Tokenizer(tokenizer_path) # load tokenizer
        self.pat = self.tokenizer.pat # get compiled regex pattern
        self.IGNORE_INDEX = IGNORE_INDEX # used for values with ' ' or PAD or CLS tokens.
        self.CLS_ID = CLS_ID
        self.PAD_ID = PAD_ID
        self.dataset = load_dataset("qanastek/MASSIVE", "en-US", split=split) # Load the english NER+IntentClassification dataset based on split


    def __len__(self):
        return len(self.dataset)


    def __getitem__(self, idx):
        ex = self.dataset[idx]
        
        text  = ex["utt"]
        slots = ex["ner_tags"]
        intent = ex["intent"]
        
        word_spans = [(m.group(), m.span()) for m in re.finditer(r"\S+", text)]        
        pieces_with_spans = [(m.group(), m.span()) for m in re.finditer(self.pat, text)]
        
        piece_labels = []
        for piece, (start, end) in pieces_with_spans:
            for idx, (_, (w0, w1)) in enumerate(word_spans):
                if w0 <= start < w1:
                    piece_labels.append(slots[idx])
                    break
            else:
                piece_labels.append(self.IGNORE_INDEX)
        
        # Expand to BPE tokens
        token_pieces = self.tokenizer.encode(text)
        flat_ids     = [self.CLS_ID]
        flat_labels  = [self.IGNORE_INDEX]
        
        for tok_ids, lab in zip(token_pieces, piece_labels):
            flat_ids.extend(tok_ids)
            flat_labels.extend([lab] * len(tok_ids))
        

        return [flat_ids, flat_labels, intent]


def collate(batch):
    flat_ids, flat_labels, intents = zip(*batch)

    # Compute the joint max length (capped by MAX_LEN)
    raw_max = max(max(len(ids), len(labs)) for ids, labs in zip(flat_ids, flat_labels))
    L = min(raw_max, MAX_LEN)

    # Pad or truncate *both* ids and labels to length L
    input_ids = []
    slot_labels = []
    attention_mask = []
    for ids, labs in zip(flat_ids, flat_labels):
        ids = ids[:L]
        labs = labs[:L]

        pad_len = L - len(ids)
        input_ids.append(ids + [PAD_ID] * pad_len)
        slot_labels.append(labs + [IGNORE_INDEX] * pad_len)
        attention_mask.append([1]*len(ids) + [0]*pad_len)

    # Stack to form tensors
    input_ids = torch.tensor(input_ids, dtype=torch.long)
    attention_mask = torch.tensor(attention_mask, dtype=torch.long)
    slot_labels = torch.tensor(slot_labels, dtype=torch.long)
    intent_labels = torch.tensor(intents, dtype=torch.long)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "slot_labels": slot_labels,
        "intent_labels": intent_labels,
    }


In [3]:
train_ds = MultiTaskDataset("bpe_merged.json", "train")
validation_ds = MultiTaskDataset("bpe_merged.json", "validation")
test_ds = MultiTaskDataset("bpe_merged.json", "test")

In [4]:
# Oversampling data with ner tags because they are under-represneted in the training dataset.
from torch.utils.data import WeightedRandomSampler

pos_indicator = []
for flat_ids, flat_labels, intent in train_ds:
    has_positive = any((lab != IGNORE_INDEX and lab != 0) for lab in flat_labels)
    pos_indicator.append(1 if has_positive else 0)

pos_alpha = 5.0
neg_alpha = 3.0

weights = [pos_alpha if p==1 else neg_alpha for p in pos_indicator]

sampler = WeightedRandomSampler(
    weights,
    num_samples=len(weights),
    replacement=True
)

In [5]:
# Finding unique number of intents and ner tags.
num_intents = len(set(train_ds.dataset["intent"] + validation_ds.dataset["intent"] + test_ds.dataset["intent"]))
num_ner_tags = train_ds.dataset.features["ner_tags"].feature.num_classes + 1

In [6]:
train_loader = DataLoader(
    train_ds,
    batch_size=32,
    sampler=sampler,
    num_workers=0,
    pin_memory=True,
    collate_fn=collate
)
val_loader = DataLoader(
    validation_ds,
    batch_size=32, shuffle=False,
    num_workers=0, pin_memory=True,
    collate_fn=collate
)
test_loader = DataLoader(
    test_ds,
    batch_size=32, shuffle=False,
    num_workers=0, pin_memory=True,
    collate_fn=collate
)

In [7]:
class MultiTaskModel(nn.Module):
    """
    Backbone = pre‐trained SentenceTransformer, plus:
     - intent model: simple linear on [CLS]
     - NER model: 2×BiLSTM → 2‐layer MLP on each timestep
    """
    def __init__(
        self,
        sentence_encoder: nn.Module,
        d_model: int,
        proj_dim: int,
        num_intents: int,
        num_ner_tags: int,
        dropout: float = DROPOUT,
        freeze_encoder: bool = True,
    ):
        super().__init__()
        self.encoder = sentence_encoder

        # intent head
        self.intent_classifier = nn.Linear(proj_dim, num_intents)

        # NER head
        self.ner_lstm = nn.LSTM(
            input_size=d_model,
            hidden_size=d_model//2,
            num_layers=2,
            bidirectional=True,
            batch_first=True,
            dropout=dropout if 2 > 1 else 0.0
        )
        self.ner_mlp = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, num_ner_tags)
        )

    def forward(self, input_ids, attention_mask):
        # encode
        cls_emb, seq_emb = self.encoder(input_ids, attention_mask=attention_mask, return_all=True)
        intent_logits = self.intent_classifier(cls_emb)
        lstm_out, _ = self.ner_lstm(seq_emb)
        ner_logits  = self.ner_mlp(lstm_out)
        return intent_logits, ner_logits


In [8]:
from sent_transformer import SentenceTransformer
from enums import VOCAB_SIZE, D_MODEL, NHEAD, NUM_LAYERS, PROJ_DIM, DROPOUT

# Load PreTrained Sentence Encoder.
encoder = SentenceTransformer(
            vocab_size=VOCAB_SIZE, d_model=D_MODEL, nhead=NHEAD, num_layers=NUM_LAYERS, proj_dim=PROJ_DIM, dropout=DROPOUT
        )

encoder.load_state_dict(torch.load('best_encoder.pt'))

<All keys matched successfully>

In [9]:
import os
from tqdm import tqdm
import torch
from torch.optim import AdamW

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize MultiTaskModel
model = MultiTaskModel(
    sentence_encoder=encoder,
    d_model=D_MODEL,
    proj_dim=PROJ_DIM,
    num_intents=num_intents,
    num_ner_tags=num_ner_tags,
    freeze_encoder=True
).to(device)


# I tried various combinations of layer freezing. ie: freezing the whole backbone, 
# unfreezing the last 2 layers of the backbone, unfreezing the last 2 layers of the backbone along with some encoder layers
# and unfreezing the whole backbone.

# I ended up unfreezing the whole backbone for eval accuracy.

encoder_params = []
for n,p in model.encoder.named_parameters():
    if p.requires_grad:
        encoder_params.append(p)
        
head_params = (
    list(model.intent_classifier.parameters()) +
    list(model.ner_lstm.parameters()) +
    list(model.ner_mlp.parameters())
)


opt = AdamW([
    {'params': encoder_params, 'lr': 5e-5},   # fine‐tune top transformer layers with a really low LR.
    {'params': head_params, 'lr': 1e-3},   # train new heads from scratch with high LR for faster convergence.
], weight_decay=1e-2)


epochs = 500
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)

# AMP setup
use_amp = torch.cuda.is_available()
scaler = torch.cuda.amp.GradScaler() if use_amp else None

# Checkpointing and early stopping setup
checkpoint_dir = "checkpoints_multitask"
os.makedirs(checkpoint_dir, exist_ok=True)

best_val_loss = float('inf')
patience = 5
no_improve = 0

intent_loss_fn = nn.CrossEntropyLoss()

weights = torch.ones(num_ner_tags)
weights *= 3
weights[[0,-1]] = 1
ner_loss_fn = nn.CrossEntropyLoss(weight=weights, ignore_index=IGNORE_INDEX)


# Training
for epoch in range(1, epochs + 1):
    model.train()
    train_loss = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch} Train"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        slot_labels = batch["slot_labels"].to(device)
        intent_labels = batch["intent_labels"].to(device)

        opt.zero_grad()

        if use_amp:
            with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                intent_logits, ner_logits = model(input_ids, attention_mask)
                intent_loss = intent_loss_fn(intent_logits, intent_labels)
                ner_loss = ner_loss_fn(ner_logits.view(-1, num_ner_tags), slot_labels.view(-1))
                loss = intent_loss + ner_loss

            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
        else:
            intent_logits, ner_logits = model(input_ids, attention_mask)
            intent_loss = intent_loss_fn(intent_logits, intent_labels)
            ner_loss = ner_loss_fn(ner_logits.view(-1, num_ner_tags), slot_labels.view(-1))
            loss = intent_loss + ner_loss

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

        train_loss += loss.item()

    train_loss /= len(train_loader)

    # Validation
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch} Val"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            slot_labels = batch["slot_labels"].to(device)
            intent_labels = batch["intent_labels"].to(device)

            if use_amp:
                with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
                    intent_logits, ner_logits = model(input_ids, attention_mask)
                    intent_loss = intent_loss_fn(intent_logits, intent_labels)
                    ner_loss = ner_loss_fn(ner_logits.view(-1, num_ner_tags), slot_labels.view(-1))
                    loss = intent_loss + ner_loss
            else:
                intent_logits, ner_logits = model(input_ids, attention_mask)
                intent_loss = intent_loss_fn(intent_logits, intent_labels)
                ner_loss = ner_loss_fn(ner_logits.view(-1, num_ner_tags), slot_labels.view(-1))
                loss = intent_loss + ner_loss

            val_loss += loss.item()

    val_loss /= len(val_loader)

    # Learning Rate Schedule
    scheduler.step()
    lr = scheduler.get_last_lr()[0]

    print(f"Epoch {epoch}/{epochs} — "
          f"train_loss={train_loss:.5f}  val_loss={val_loss:.5f}  lr={lr:.1e}")

    # Saving Checkpoints
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': opt.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
    }, os.path.join(checkpoint_dir, f"multitask_epoch_{epoch}.pt"))

    # Early Stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improve = 0
        torch.save(model.state_dict(), "best_multitask_model.pt")
        print("  ↳ New best MultiTaskModel saved.")
    else:
        no_improve += 1
        print(f"  ↳ No improvement for {no_improve} epoch(s).")
        if no_improve >= patience:
            print(f"Stopping early after {patience} epochs without improvement.")
            break


Epoch 1 Train: 100%|██████████████████████████| 360/360 [00:29<00:00, 12.00it/s]
  output = torch._nested_tensor_from_mask(
Epoch 1 Val: 100%|██████████████████████████████| 64/64 [00:02<00:00, 29.82it/s]


Epoch 1/500 — train_loss=6.47528  val_loss=5.66544  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 2 Train: 100%|██████████████████████████| 360/360 [00:31<00:00, 11.59it/s]
Epoch 2 Val: 100%|██████████████████████████████| 64/64 [00:02<00:00, 29.09it/s]


Epoch 2/500 — train_loss=5.44485  val_loss=5.11691  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 3 Train: 100%|██████████████████████████| 360/360 [00:30<00:00, 11.65it/s]
Epoch 3 Val: 100%|██████████████████████████████| 64/64 [00:02<00:00, 27.75it/s]


Epoch 3/500 — train_loss=4.99013  val_loss=4.79890  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 4 Train: 100%|██████████████████████████| 360/360 [00:38<00:00,  9.30it/s]
Epoch 4 Val: 100%|██████████████████████████████| 64/64 [00:03<00:00, 20.44it/s]


Epoch 4/500 — train_loss=4.65059  val_loss=4.61544  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 5 Train: 100%|██████████████████████████| 360/360 [00:38<00:00,  9.28it/s]
Epoch 5 Val: 100%|██████████████████████████████| 64/64 [00:02<00:00, 25.73it/s]


Epoch 5/500 — train_loss=4.48988  val_loss=4.46833  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 6 Train: 100%|██████████████████████████| 360/360 [00:37<00:00,  9.70it/s]
Epoch 6 Val: 100%|██████████████████████████████| 64/64 [00:02<00:00, 22.26it/s]


Epoch 6/500 — train_loss=4.32144  val_loss=4.38340  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 7 Train: 100%|██████████████████████████| 360/360 [00:39<00:00,  9.01it/s]
Epoch 7 Val: 100%|██████████████████████████████| 64/64 [00:02<00:00, 25.86it/s]


Epoch 7/500 — train_loss=4.15267  val_loss=4.27848  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 8 Train: 100%|██████████████████████████| 360/360 [00:38<00:00,  9.34it/s]
Epoch 8 Val: 100%|██████████████████████████████| 64/64 [00:02<00:00, 23.37it/s]


Epoch 8/500 — train_loss=4.04518  val_loss=4.19301  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 9 Train: 100%|██████████████████████████| 360/360 [00:37<00:00,  9.71it/s]
Epoch 9 Val: 100%|██████████████████████████████| 64/64 [00:02<00:00, 25.25it/s]


Epoch 9/500 — train_loss=3.92093  val_loss=4.15606  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 10 Train: 100%|█████████████████████████| 360/360 [00:40<00:00,  8.93it/s]
Epoch 10 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 23.86it/s]


Epoch 10/500 — train_loss=3.87599  val_loss=4.09593  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 11 Train: 100%|█████████████████████████| 360/360 [00:39<00:00,  9.07it/s]
Epoch 11 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 24.83it/s]


Epoch 11/500 — train_loss=3.76851  val_loss=4.05496  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 12 Train: 100%|█████████████████████████| 360/360 [00:38<00:00,  9.38it/s]
Epoch 12 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 22.27it/s]


Epoch 12/500 — train_loss=3.67930  val_loss=4.01540  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 13 Train: 100%|█████████████████████████| 360/360 [00:34<00:00, 10.50it/s]
Epoch 13 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 27.44it/s]


Epoch 13/500 — train_loss=3.63410  val_loss=3.98235  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 14 Train: 100%|█████████████████████████| 360/360 [00:31<00:00, 11.35it/s]
Epoch 14 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 28.17it/s]


Epoch 14/500 — train_loss=3.53167  val_loss=3.93293  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 15 Train: 100%|█████████████████████████| 360/360 [00:32<00:00, 11.18it/s]
Epoch 15 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 22.70it/s]


Epoch 15/500 — train_loss=3.47029  val_loss=3.93580  lr=5.0e-05
  ↳ No improvement for 1 epoch(s).


Epoch 16 Train: 100%|█████████████████████████| 360/360 [00:36<00:00,  9.74it/s]
Epoch 16 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 20.41it/s]


Epoch 16/500 — train_loss=3.44868  val_loss=3.91388  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 17 Train: 100%|█████████████████████████| 360/360 [00:41<00:00,  8.68it/s]
Epoch 17 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 20.74it/s]


Epoch 17/500 — train_loss=3.37780  val_loss=3.89755  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 18 Train: 100%|█████████████████████████| 360/360 [00:35<00:00, 10.00it/s]
Epoch 18 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 20.02it/s]


Epoch 18/500 — train_loss=3.32050  val_loss=3.85002  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 19 Train: 100%|█████████████████████████| 360/360 [00:38<00:00,  9.29it/s]
Epoch 19 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 20.25it/s]


Epoch 19/500 — train_loss=3.28095  val_loss=3.85798  lr=5.0e-05
  ↳ No improvement for 1 epoch(s).


Epoch 20 Train: 100%|█████████████████████████| 360/360 [00:39<00:00,  9.02it/s]
Epoch 20 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 20.80it/s]


Epoch 20/500 — train_loss=3.22935  val_loss=3.82517  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 21 Train: 100%|█████████████████████████| 360/360 [00:42<00:00,  8.45it/s]
Epoch 21 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 20.70it/s]


Epoch 21/500 — train_loss=3.20240  val_loss=3.80907  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 22 Train: 100%|█████████████████████████| 360/360 [00:38<00:00,  9.42it/s]
Epoch 22 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 27.84it/s]


Epoch 22/500 — train_loss=3.18498  val_loss=3.76892  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 23 Train: 100%|█████████████████████████| 360/360 [00:31<00:00, 11.60it/s]
Epoch 23 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 27.25it/s]


Epoch 23/500 — train_loss=3.12497  val_loss=3.78784  lr=5.0e-05
  ↳ No improvement for 1 epoch(s).


Epoch 24 Train: 100%|█████████████████████████| 360/360 [00:31<00:00, 11.44it/s]
Epoch 24 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 26.26it/s]


Epoch 24/500 — train_loss=3.08083  val_loss=3.76415  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 25 Train: 100%|█████████████████████████| 360/360 [00:42<00:00,  8.54it/s]
Epoch 25 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 20.53it/s]


Epoch 25/500 — train_loss=3.04667  val_loss=3.74696  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 26 Train: 100%|█████████████████████████| 360/360 [00:43<00:00,  8.20it/s]
Epoch 26 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 20.64it/s]


Epoch 26/500 — train_loss=3.00353  val_loss=3.73345  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 27 Train: 100%|█████████████████████████| 360/360 [00:42<00:00,  8.46it/s]
Epoch 27 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 21.40it/s]


Epoch 27/500 — train_loss=3.00477  val_loss=3.75959  lr=5.0e-05
  ↳ No improvement for 1 epoch(s).


Epoch 28 Train: 100%|█████████████████████████| 360/360 [00:41<00:00,  8.59it/s]
Epoch 28 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 21.10it/s]


Epoch 28/500 — train_loss=2.97399  val_loss=3.74448  lr=5.0e-05
  ↳ No improvement for 2 epoch(s).


Epoch 29 Train: 100%|█████████████████████████| 360/360 [00:42<00:00,  8.45it/s]
Epoch 29 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 21.12it/s]


Epoch 29/500 — train_loss=2.93440  val_loss=3.72384  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 30 Train: 100%|█████████████████████████| 360/360 [00:43<00:00,  8.25it/s]
Epoch 30 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 19.90it/s]


Epoch 30/500 — train_loss=2.87670  val_loss=3.70255  lr=5.0e-05
  ↳ New best MultiTaskModel saved.


Epoch 31 Train: 100%|█████████████████████████| 360/360 [00:43<00:00,  8.22it/s]
Epoch 31 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 17.71it/s]


Epoch 31/500 — train_loss=2.86865  val_loss=3.71390  lr=5.0e-05
  ↳ No improvement for 1 epoch(s).


Epoch 32 Train: 100%|█████████████████████████| 360/360 [00:41<00:00,  8.59it/s]
Epoch 32 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 20.45it/s]


Epoch 32/500 — train_loss=2.87922  val_loss=3.72067  lr=4.9e-05
  ↳ No improvement for 2 epoch(s).


Epoch 33 Train: 100%|█████████████████████████| 360/360 [00:44<00:00,  8.12it/s]
Epoch 33 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 19.21it/s]


Epoch 33/500 — train_loss=2.85089  val_loss=3.69623  lr=4.9e-05
  ↳ New best MultiTaskModel saved.


Epoch 34 Train: 100%|█████████████████████████| 360/360 [00:42<00:00,  8.43it/s]
Epoch 34 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 19.68it/s]


Epoch 34/500 — train_loss=2.82408  val_loss=3.70895  lr=4.9e-05
  ↳ No improvement for 1 epoch(s).


Epoch 35 Train: 100%|█████████████████████████| 360/360 [00:41<00:00,  8.74it/s]
Epoch 35 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 21.70it/s]


Epoch 35/500 — train_loss=2.75990  val_loss=3.67709  lr=4.9e-05
  ↳ New best MultiTaskModel saved.


Epoch 36 Train: 100%|█████████████████████████| 360/360 [00:38<00:00,  9.24it/s]
Epoch 36 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 20.84it/s]


Epoch 36/500 — train_loss=2.74447  val_loss=3.68558  lr=4.9e-05
  ↳ No improvement for 1 epoch(s).


Epoch 37 Train: 100%|█████████████████████████| 360/360 [00:38<00:00,  9.32it/s]
Epoch 37 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 24.60it/s]


Epoch 37/500 — train_loss=2.71042  val_loss=3.67154  lr=4.9e-05
  ↳ New best MultiTaskModel saved.


Epoch 38 Train: 100%|█████████████████████████| 360/360 [00:38<00:00,  9.47it/s]
Epoch 38 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 23.34it/s]


Epoch 38/500 — train_loss=2.69420  val_loss=3.65832  lr=4.9e-05
  ↳ New best MultiTaskModel saved.


Epoch 39 Train: 100%|█████████████████████████| 360/360 [00:41<00:00,  8.72it/s]
Epoch 39 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 21.39it/s]


Epoch 39/500 — train_loss=2.72038  val_loss=3.66981  lr=4.9e-05
  ↳ No improvement for 1 epoch(s).


Epoch 40 Train: 100%|█████████████████████████| 360/360 [00:39<00:00,  9.17it/s]
Epoch 40 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 19.40it/s]


Epoch 40/500 — train_loss=2.66745  val_loss=3.67334  lr=4.9e-05
  ↳ No improvement for 2 epoch(s).


Epoch 41 Train: 100%|█████████████████████████| 360/360 [00:40<00:00,  8.99it/s]
Epoch 41 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 27.70it/s]


Epoch 41/500 — train_loss=2.67989  val_loss=3.65137  lr=4.9e-05
  ↳ New best MultiTaskModel saved.


Epoch 42 Train: 100%|█████████████████████████| 360/360 [00:39<00:00,  9.18it/s]
Epoch 42 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 23.91it/s]


Epoch 42/500 — train_loss=2.62011  val_loss=3.66198  lr=4.9e-05
  ↳ No improvement for 1 epoch(s).


Epoch 43 Train: 100%|█████████████████████████| 360/360 [00:37<00:00,  9.48it/s]
Epoch 43 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 24.55it/s]


Epoch 43/500 — train_loss=2.61695  val_loss=3.66739  lr=4.9e-05
  ↳ No improvement for 2 epoch(s).


Epoch 44 Train: 100%|█████████████████████████| 360/360 [00:38<00:00,  9.29it/s]
Epoch 44 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 25.24it/s]


Epoch 44/500 — train_loss=2.58712  val_loss=3.67482  lr=4.9e-05
  ↳ No improvement for 3 epoch(s).


Epoch 45 Train: 100%|█████████████████████████| 360/360 [00:38<00:00,  9.24it/s]
Epoch 45 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 22.99it/s]


Epoch 45/500 — train_loss=2.60076  val_loss=3.62240  lr=4.9e-05
  ↳ New best MultiTaskModel saved.


Epoch 46 Train: 100%|█████████████████████████| 360/360 [00:39<00:00,  9.05it/s]
Epoch 46 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 24.03it/s]


Epoch 46/500 — train_loss=2.57241  val_loss=3.64758  lr=4.9e-05
  ↳ No improvement for 1 epoch(s).


Epoch 47 Train: 100%|█████████████████████████| 360/360 [00:41<00:00,  8.76it/s]
Epoch 47 Val: 100%|█████████████████████████████| 64/64 [00:03<00:00, 20.02it/s]


Epoch 47/500 — train_loss=2.57086  val_loss=3.64969  lr=4.9e-05
  ↳ No improvement for 2 epoch(s).


Epoch 48 Train: 100%|█████████████████████████| 360/360 [00:42<00:00,  8.52it/s]
Epoch 48 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 24.95it/s]


Epoch 48/500 — train_loss=2.56847  val_loss=3.64161  lr=4.9e-05
  ↳ No improvement for 3 epoch(s).


Epoch 49 Train: 100%|█████████████████████████| 360/360 [00:41<00:00,  8.59it/s]
Epoch 49 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 21.48it/s]


Epoch 49/500 — train_loss=2.53413  val_loss=3.64725  lr=4.9e-05
  ↳ No improvement for 4 epoch(s).


Epoch 50 Train: 100%|█████████████████████████| 360/360 [00:40<00:00,  8.99it/s]
Epoch 50 Val: 100%|█████████████████████████████| 64/64 [00:02<00:00, 23.56it/s]

Epoch 50/500 — train_loss=2.49843  val_loss=3.62813  lr=4.9e-05
  ↳ No improvement for 5 epoch(s).
Stopping early after 5 epochs without improvement.





In [10]:
# Sanity Check on a few samples on the test dataset
# Upon a quick check, intent classification performs reasonably well.
# And the NER model needs further training of the backbone on a larger architecture and dataset. 
# But this is a good first step as we see some okay predictions.

for idx in np.random.randint(1,100,10):
    batch = collate([test_ds[int(idx)]])
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    slot_labels = batch["slot_labels"].to(device)
    intent_labels = batch["intent_labels"].to(device)


    with torch.no_grad():
        intent_logits, ner_logits = model(input_ids, attention_mask)
        intent_loss = intent_loss_fn(intent_logits, intent_labels)
        ner_loss = ner_loss_fn(ner_logits.view(-1, num_ner_tags), slot_labels.view(-1))

        intent_preds = int(np.argmax(intent_logits))
        intent_labels = intent_labels[0]    
        
        ner_preds = list(map(int, (map(np.argmax, ner_logits[0]))))
        ner_labels = list(map(int, np.array(slot_labels)[0]))

    print(f"intent prediction: {intent_preds}")
    print(f"intent label     : {intent_labels}")
    print(f"ner prediction: {ner_preds}")
    print(f"ner lablel    : {ner_labels}")
    print()


intent prediction: 46
intent label     : 46
ner prediction: [0, 0, 0, 0, 0, 7, 7, 7, 7, 7, 7, 7]
ner lablel    : [111, 0, 0, 111, 0, 111, 7, 7, 7, 111, 0, 0]

intent prediction: 31
intent label     : 31
ner prediction: [0, 0, 0, 0, 0, 59, 59, 59, 59, 59, 59, 59, 59, 0]
ner lablel    : [111, 0, 0, 111, 0, 0, 111, 0, 111, 0, 111, 0, 0, 0]

intent prediction: 1
intent label     : 1
ner prediction: [0, 0, 0, 0, 0, 77, 77, 77, 77, 77, 0, 0, 0, 0]
ner lablel    : [111, 0, 111, 0, 111, 77, 77, 77, 77, 77, 111, 20, 20, 20]

intent prediction: 57
intent label     : 57
ner prediction: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 69, 69, 69, 69, 69, 69, 69, 69]
ner lablel    : [111, 0, 0, 111, 0, 111, 0, 111, 0, 0, 0, 111, 0, 111, 0, 111, 69, 69, 69, 69]

intent prediction: 10
intent label     : 7
ner prediction: [0, 0, 0, 0, 0, 0, 0, 0]
ner lablel    : [111, 0, 0, 0, 111, 0, 0, 0]

intent prediction: 53
intent label     : 1
ner prediction: [0, 0, 80, 0, 0, 0]
ner lablel    : [111, 22, 22, 22, 111, 24]



  ner_labels = list(map(int, np.array(slot_labels)[0]))
