In [4]:
import torch
import torch.nn.functional as F

from datasets import load_dataset

In [5]:
# Hub dataset with "train", "validation" and "test" splits; its "tokens" / "ner_tags"
# columns are consumed by tokenize_and_align_labels.
dataset = load_dataset(path="surrey-nlp/PLOD-CW-25")

In [6]:
# Tag set: "O" outside, B-AC for AC tokens (no I-AC tag exists), B-LF / I-LF begin / inside of an LF span.
labels = ["O", "B-AC", "B-LF", "I-LF"]
n_labels = len(labels)
ltoi = dict(zip(labels, range(n_labels)))  # label string -> integer id
itol = dict(enumerate(labels))             # integer id -> label string

In [7]:
from transformers import AutoTokenizer

model_checkpoint = "microsoft/deberta-v3-base"  # any other DeBERTa-v3 size (e.g. deberta-v3-large) also works
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_checkpoint)

# Label alignment relies on BatchEncoding.word_ids(), which only fast tokenizers provide.
assert tokenizer.is_fast, "Only fast tokenizers are supported for this example."



In [2]:
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split words and attach one label id per sub-token.

    Sequences are truncated/padded to exactly 512 tokens. Per sub-token position:
      * word id None (special tokens and padding) -> "O"; padding is excluded later by the
        attention mask, while special tokens keep "O" as their label;
      * first sub-token of a word -> the word's own tag;
      * later sub-tokens of a word -> "I-X" if the word is tagged "B-X" and "I-X" is in
        ``ltoi``, otherwise the word's own tag (so every piece of a "B-AC" word is "B-AC").

    NOTE(review): repeated "B-AC" pieces look like separate entity starts to IOB-based
    scorers such as seqeval, so subword-level scores may count one multi-piece word as
    several AC entities — confirm this is the intended evaluation granularity.

    Args:
        examples: batched mapping with "tokens" (list of word lists) and "ner_tags"
            (list of per-word tag-string lists).

    Returns:
        The tokenizer's BatchEncoding with an added "labels" entry (list of id lists).
    """
    tokenized_inputs = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True,
        max_length=512,
        padding="max_length"
    )

    def continuation_tag(tag):
        # A word-internal piece of a "B-X" word becomes "I-X" when that tag exists.
        if tag.startswith("B-") and "I-" + tag[2:] in ltoi:
            return "I-" + tag[2:]
        return tag

    all_label_ids = []
    for row, word_tags in enumerate(examples["ner_tags"]):
        row_label_ids = []
        prev_word = None
        for word in tokenized_inputs.word_ids(batch_index=row):
            if word is None:
                tag = "O"
            elif word == prev_word:
                tag = continuation_tag(word_tags[word])
            else:
                tag = word_tags[word]
            row_label_ids.append(ltoi[tag])
            prev_word = word
        all_label_ids.append(row_label_ids)

    tokenized_inputs["labels"] = all_label_ids
    return tokenized_inputs

In [8]:
# Tokenize every split in batches; each split gains input_ids / attention_mask / labels columns.
data = dataset.map(function=tokenize_and_align_labels, batched=True)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/250 [00:00<?, ? examples/s]

Map:   0%|          | 0/150 [00:00<?, ? examples/s]

In [18]:
def _split_arrays(split_name):
    """Return the (input_ids, labels, attention_mask) columns of one tokenized split."""
    split_ds = data[split_name]
    return split_ds['input_ids'], split_ds['labels'], split_ds['attention_mask']

train_data, train_labels, train_attention_mask = _split_arrays('train')
val_data, val_labels, val_attention_mask = _split_arrays('validation')
test_data, test_labels, test_attention_mask = _split_arrays('test')

In [42]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 8  # sequences per batch drawn by get_batch

# Seed torch's RNGs (all devices) so random batch sampling and the freshly initialized
# classifier head / CRF weights created later are reproducible after a kernel restart.
SEED = 1337
torch.manual_seed(SEED)

def get_batch(split = "train"):
  """Sample `batch_size` random examples (with replacement) from one tokenized split.

  Args:
    split: "train", "validation" or "test".

  Returns:
    (input_ids, labels, attention_mask) int64 tensors of shape (batch_size, seq_len),
    moved to `device`.

  Raises:
    ValueError: for any other split name (previously every non-"train" value silently
      fell back to the validation data, so e.g. "test" returned validation batches).
  """
  if split == "train":
    ids, tags, masks = train_data, train_labels, train_attention_mask
  elif split == "validation":
    ids, tags, masks = val_data, val_labels, val_attention_mask
  elif split == "test":
    ids, tags, masks = test_data, test_labels, test_attention_mask
  else:
    raise ValueError(f"Invalid split: {split}. Choose from 'train', 'validation', 'test'.")

  # Same RNG call as before, converted to plain ints for list indexing.
  idx = torch.randint(len(ids), (batch_size,)).tolist()
  x = torch.tensor([ids[i] for i in idx], dtype=torch.long)
  y = torch.tensor([tags[i] for i in idx], dtype=torch.long)
  a = torch.tensor([masks[i] for i in idx], dtype=torch.long)
  return x.to(device), y.to(device), a.to(device)

@torch.no_grad()
def estimate_loss(eval_steps):
  """Average the model's CRF loss over `eval_steps` random batches per split.

  Evaluates "train" then "validation" with the model in eval mode and no gradients,
  then puts the model back into train mode.

  Returns:
    {"train": mean_loss, "validation": mean_loss} where each value is a 0-d tensor.
  """
  def mean_split_loss(split):
    per_batch = torch.zeros(eval_steps)
    for idx in range(eval_steps):
      xb, yb, ab = get_batch(split)
      batch_loss, _ = model(xb, attention_mask = ab, labels = yb)
      per_batch[idx] = batch_loss.item()
    return per_batch.mean()

  model.eval()
  result = {split: mean_split_loss(split) for split in ("train", "validation")}
  model.train()
  return result

In [43]:
from transformers import AutoModelForTokenClassification
from torchcrf import CRF

In [44]:
class DebertaCrfForTokenClassification(torch.nn.Module):
    """Token-classification backbone whose per-token logits feed a linear-chain CRF.

    The backbone's classification head emits ``num_labels`` scores per token
    (the CRF emissions); the CRF adds learned tag-transition scores, is trained on
    the mean negative log-likelihood and decodes with Viterbi.
    """

    def __init__(self, model_name, num_labels, id2label, label2id):
        super().__init__()
        self.num_labels = num_labels
        self.deberta = AutoModelForTokenClassification.from_pretrained(
            model_name,
            num_labels=num_labels,  # classifier head outputs one score per label
            id2label=id2label,
            label2id=label2id
        )
        self.crf = CRF(num_tags=num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask=None, labels=None):
        """Compute the CRF loss (when ``labels`` is given) or Viterbi-decoded tags.

        Args:
            input_ids: token ids, (batch_size, seq_length).
            attention_mask: 1 for real tokens, 0 for padding; None treats every position
                as active. The CRF expects the active positions to form a prefix.
            labels: optional gold tag ids, (batch_size, seq_length).

        Returns:
            (loss, emissions) with labels: loss is the negative CRF log-likelihood
            averaged over the batch. Otherwise (decoded_tags, emissions): decoded_tags is
            one list of tag ids per sequence covering only its active positions.
            ``emissions`` are the backbone logits as produced (possibly reduced precision
            under autocast).
        """
        outputs = self.deberta(input_ids=input_ids, attention_mask=attention_mask)
        emissions = outputs.logits  # (batch_size, seq_length, num_labels)

        # Under torch.autocast the logits come out in bf16/fp16; run the CRF's
        # forward algorithm / Viterbi on float32 scores instead.
        crf_emissions = emissions.float()

        # Build an explicit bool mask: torchcrf's own default for mask=None is uint8,
        # which recent torch.where calls reject.
        if attention_mask is None:
            crf_mask = torch.ones_like(input_ids, dtype=torch.bool)
        else:
            crf_mask = attention_mask.bool()

        if labels is not None:
            loss = -self.crf(crf_emissions, labels, mask=crf_mask, reduction='mean')
            return loss, emissions

        decoded_tags = self.crf.decode(crf_emissions, mask=crf_mask)
        return decoded_tags, emissions


In [45]:
# Backbone + CRF sized to the tag set; the label maps give the HF config readable label names.
model = DebertaCrfForTokenClassification(
    model_name=model_checkpoint,
    num_labels=n_labels,
    id2label=itol,
    label2id=ltoi,
)
model = model.to(device)

Some weights of DebertaV2ForTokenClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [46]:
from tqdm import tqdm

LEARNING_RATE = 1e-5
# NOTE(review): this single rate also drives the CRF transition parameters; BERT+CRF
# setups often give the CRF its own, larger-LR parameter group — worth trying.
optim = torch.optim.AdamW(params=model.parameters(), lr=LEARNING_RATE)
max_steps = 1000  # number of optimizer steps in the training loop

In [47]:
# Mixed precision: bf16 autocast only where it is available (a CUDA device with bf16
# support); elsewhere train in full precision. fp16 is deliberately not used as a fallback
# because fp16 training would also need a GradScaler to avoid gradient underflow.
use_bf16_autocast = device.type == "cuda" and torch.cuda.is_bf16_supported()
EVAL_INTERVAL = 100  # estimate losses every N steps (and after the final step)
EVAL_BATCHES = 200   # random batches per split in each estimate_loss() call

lossi = []  # training loss at every step
lri = []    # learning rate at every step (constant: no scheduler is configured)

for step in tqdm(range(max_steps)):
  x, y, a = get_batch("train")
  optim.zero_grad()
  with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=use_bf16_autocast):
    loss, logits = model(x, attention_mask = a, labels = y)
  loss.backward()
  optim.step()
  lossi.append(loss.item())
  lri.append(optim.param_groups[0]['lr'])
  if step % EVAL_INTERVAL == 0 or step == max_steps - 1:
    losses = estimate_loss(EVAL_BATCHES)
    # tqdm.write prints above the progress bar instead of breaking it.
    tqdm.write(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['validation']:.4f}")


  0%|          | 1/1000 [01:27<24:20:10, 87.70s/it]

step 0: train loss 63.5596, val loss 64.2739


 10%|█         | 101/1000 [03:39<6:43:25, 26.92s/it]

step 100: train loss 14.3790, val loss 16.6687


 20%|██        | 201/1000 [05:51<5:58:47, 26.94s/it]

step 200: train loss 11.9993, val loss 14.0732


 30%|███       | 301/1000 [08:02<5:14:00, 26.95s/it]

step 300: train loss 9.2189, val loss 12.3455


 40%|████      | 401/1000 [10:14<4:29:19, 26.98s/it]

step 400: train loss 8.1032, val loss 12.6389


 50%|█████     | 501/1000 [12:26<3:44:28, 26.99s/it]

step 500: train loss 7.9043, val loss 12.9673


 60%|██████    | 601/1000 [14:38<2:59:19, 26.97s/it]

step 600: train loss 7.3925, val loss 12.5551


 70%|███████   | 701/1000 [16:50<2:14:36, 27.01s/it]

step 700: train loss 8.0750, val loss 12.9865


 80%|████████  | 800/1000 [17:47<04:26,  1.33s/it]  


KeyboardInterrupt: 

In [48]:
from seqeval.metrics import f1_score, recall_score, precision_score, classification_report

In [49]:
def _crf_sample_label_strings(true_ids, pred_ids, mask, id_to_label):
    """Turn one sample's gold ids and CRF-decoded ids into aligned label-string lists.

    Only positions whose mask value is 1 are kept. crf.decode() yields one tag per active
    position, so the i-th decoded tag is paired with the i-th active position (for
    right-padded masks this equals pairing by raw index).

    Args:
        true_ids: gold tag ids for every position of the padded sequence.
        pred_ids: decoded tag ids for the active positions.
        mask: attention-mask values (1 = active) for every position.
        id_to_label: mapping from tag id to label string.

    Returns:
        (true_labels, pred_labels); equal length when pred_ids covers every active position.
    """
    active = [pos for pos, m in enumerate(mask) if m == 1]
    true_labels = [id_to_label[true_ids[pos]] for pos in active]
    pred_labels = [id_to_label[tag] for tag in pred_ids[:len(active)]]
    return true_labels, pred_labels


@torch.no_grad()
def evaluate_model_crf(split="test", current_model=None, current_device=None,
                       current_itol=None,
                       current_test_data=None, current_test_labels=None, current_test_attention_mask=None,
                       current_val_data=None, current_val_labels=None, current_val_attention_mask=None,
                       current_train_data=None, current_train_labels=None, current_train_attention_mask=None):
    """Evaluate a CRF tagger on one split with seqeval entity-level metrics.

    Each ``current_*`` argument, when not None, replaces the notebook global with the same
    role (``model``, ``device``, ``itol`` and the split's input-id / label / attention-mask
    lists) for this call only; the globals themselves are never rebound.

    Scores are computed over sub-token positions with attention mask 1, which includes
    special tokens (gold label "O"), not over the original words.

    Args:
        split: "train", "validation" or "test".
        current_model: model whose no-label forward returns (decoded_tags, emissions).
        current_device: torch.device or device string to run on.
        current_itol: mapping from tag id to label string.
        current_*_data / _labels / _attention_mask: per-split list overrides.

    Returns:
        dict with "f1", "precision", "recall" and "report" (seqeval classification report
        text). Scores are 0.0 when no sequence has an active position.

    Raises:
        ValueError: if ``split`` is not one of the names above.
    """
    # Resolve overrides into *local* names. The previous `global model` /
    # `model = current_model` pattern made the whole function use the global, so any
    # override permanently replaced the notebook's model / device / itol.
    eval_model = model if current_model is None else current_model
    eval_device = device if current_device is None else torch.device(current_device)
    id_to_label = itol if current_itol is None else current_itol

    if split == "test":
        data_input_ids = current_test_data if current_test_data is not None else test_data
        data_labels = current_test_labels if current_test_labels is not None else test_labels
        data_attention_mask = current_test_attention_mask if current_test_attention_mask is not None else test_attention_mask
    elif split == "validation":
        data_input_ids = current_val_data if current_val_data is not None else val_data
        data_labels = current_val_labels if current_val_labels is not None else val_labels
        data_attention_mask = current_val_attention_mask if current_val_attention_mask is not None else val_attention_mask
    elif split == "train":
        data_input_ids = current_train_data if current_train_data is not None else train_data
        data_labels = current_train_labels if current_train_labels is not None else train_labels
        data_attention_mask = current_train_attention_mask if current_train_attention_mask is not None else train_attention_mask
    else:
        raise ValueError(f"Invalid split: {split}. Choose from 'train', 'validation', 'test'.")

    use_cuda_autocast = eval_device.type == "cuda"
    amp_dtype = torch.bfloat16 if use_cuda_autocast and torch.cuda.is_bf16_supported() else torch.float16

    batch_size_eval = 16  # sequences per forward pass; lower it if evaluation runs out of memory
    all_true_labels_str = []
    all_pred_labels_str = []

    was_training = eval_model.training
    eval_model.eval()
    try:
        for start in tqdm(range(0, len(data_input_ids), batch_size_eval), desc=f"Evaluating on {split} with CRF"):
            stop = start + batch_size_eval
            batch_labels_list = data_labels[start:stop]
            batch_mask_list = data_attention_mask[start:stop]
            batch_input_ids = torch.tensor(data_input_ids[start:stop], dtype=torch.long, device=eval_device)
            # int64 mask, the same dtype get_batch feeds the model during training.
            batch_attention_mask = torch.tensor(batch_mask_list, dtype=torch.long, device=eval_device)

            # Without labels the model returns (decoded_tags, emissions): one list of tag
            # ids per sequence, covering only its active positions.
            if use_cuda_autocast:
                with torch.autocast(device_type="cuda", dtype=amp_dtype):
                    predicted_tags_batch, _ = eval_model(input_ids=batch_input_ids, attention_mask=batch_attention_mask)
            else:
                predicted_tags_batch, _ = eval_model(input_ids=batch_input_ids, attention_mask=batch_attention_mask)

            # Gold labels and masks come from the host-side lists; no GPU round-trip needed.
            for true_ids, pred_ids, mask in zip(batch_labels_list, predicted_tags_batch, batch_mask_list):
                true_seq_str, pred_seq_str = _crf_sample_label_strings(true_ids, pred_ids, mask, id_to_label)
                if true_seq_str:  # skip sequences with no active position
                    all_true_labels_str.append(true_seq_str)
                    all_pred_labels_str.append(pred_seq_str)
    finally:
        # Restore the caller's train/eval mode on every exit path (including errors and
        # the empty-data return below, which previously left the model in eval mode).
        eval_model.train(was_training)

    if not all_true_labels_str or not all_pred_labels_str:
        print(f"No data to evaluate for split {split}. True labels count: {len(all_true_labels_str)}, Pred labels count: {len(all_pred_labels_str)}")
        return {
            "f1": 0.0,
            "precision": 0.0,
            "recall": 0.0,
            "report": "No data to evaluate."
        }

    precision = precision_score(all_true_labels_str, all_pred_labels_str)
    recall = recall_score(all_true_labels_str, all_pred_labels_str)
    f1 = f1_score(all_true_labels_str, all_pred_labels_str)
    report = classification_report(all_true_labels_str, all_pred_labels_str, digits=4)

    print(f"\n=== CRF Model Evaluation on {split} split ===")
    print(f"F1 Score (overall): {f1:.4f}")
    print(f"Precision (overall): {precision:.4f}")
    print(f"Recall (overall): {recall:.4f}")
    print("\nDetailed Classification Report:")
    print(report)

    return {
        "f1": f1,
        "precision": precision,
        "recall": recall,
        "report": report
    }

In [50]:
# Score the trained model on the held-out test split (notebook globals for model/data).
evaluate_model_crf(split="test")

Evaluating on test with CRF: 100%|██████████| 16/16 [00:02<00:00,  5.50it/s]



=== CRF Model Evaluation on test split ===
F1 Score (overall): 0.8634
Precision (overall): 0.8474
Recall (overall): 0.8799

Detailed Classification Report:
              precision    recall  f1-score   support

          AC     0.8799    0.8733    0.8766      1342
          LF     0.7705    0.8983    0.8295       482

   micro avg     0.8474    0.8799    0.8634      1824
   macro avg     0.8252    0.8858    0.8530      1824
weighted avg     0.8510    0.8799    0.8641      1824



{'f1': 0.8633674018289403,
 'precision': 0.8474128827877508,
 'recall': 0.8799342105263158,
 'report': '              precision    recall  f1-score   support\n\n          AC     0.8799    0.8733    0.8766      1342\n          LF     0.7705    0.8983    0.8295       482\n\n   micro avg     0.8474    0.8799    0.8634      1824\n   macro avg     0.8252    0.8858    0.8530      1824\nweighted avg     0.8510    0.8799    0.8641      1824\n'}