In [1]:
import torch
import torch.nn.functional as F

from datasets import load_dataset

In [2]:
dataset = load_dataset("surrey-nlp/PLOD-CW-25")

HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/surrey-nlp/PLOD-CW-25/resolve/main/README.md
Retrying in 1s [Retry 1/5].
HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/surrey-nlp/PLOD-CW-25/resolve/main/README.md
Retrying in 2s [Retry 2/5].
HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/surrey-nlp/PLOD-CW-25/resolve/main/README.md
Retrying in 4s [Retry 3/5].
HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/surrey-nlp/PLOD-CW-25/resolve/main/README.md
Retrying in 8s [Retry 4/5].
HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/surrey-nlp/PLOD-CW-25/resolve/main/README.md
Retrying in 8s [Retry 5/5].
HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/surrey-nlp/PLOD-CW-25/resolve/main/README.md
HTTP Error 429 thrown while requesting HEAD https://huggingface.co/datasets/surrey-nlp/PLOD-CW-25/resolve/9e3083d6df56ff798c62bd3fe0ba

In [3]:
labels = ["O", "B-AC", "B-LF", "I-LF"]
n_labels = len(labels)
ltoi = {l: i for i, l in enumerate(labels)}
itol = {i: l for l, i in ltoi.items()}

In [4]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


# Load model from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

In [5]:
def compute_embeds(examples):
    encoded_input = tokenizer(examples['tokens'], padding=True, truncation=True, return_tensors='pt', is_split_into_words=True)

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings

In [6]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True, # Crucial for pre-tokenized input
        max_length=512,
        padding="max_length"

    )

    labels = []
    for i, label_sequence in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None: # Special tokens ([CLS], [SEP])
                label_ids.append(-100)
            elif word_idx != previous_word_idx: # First token of a new word
                label_ids.append(ltoi[label_sequence[word_idx]])
            else: # Subsequent tokens of the same word
                label_ids.append(-100)


            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [7]:
data = dataset.map(tokenize_and_align_labels, batched = True)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

In [8]:
train_data, train_labels, train_attention_mask = data['train']['input_ids'], data['train']['labels'], data['train']['attention_mask']
val_data, val_labels, val_attention_mask = data['validation']['input_ids'], data['validation']['labels'], data['validation']['attention_mask']
test_data, test_labels, test_attention_mask = data['test']['input_ids'], data['test']['labels'], data['test']['attention_mask']

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 8

def get_batch(split = "train"):
  data = train_data if split == "train" else val_data
  labels = train_labels if split == "train" else val_labels
  attention_mask = train_attention_mask if split == "train" else val_attention_mask
  ix = torch.randint(len(data), (batch_size,))
  x = torch.stack([torch.tensor(data[i]).long() for i in ix])
  y = torch.stack([torch.tensor(labels[i]).long() for i in ix])
  a = torch.stack([torch.tensor(attention_mask[i]) for i in ix])
  return x.to(device), y.to(device), a.to(device)


@torch.no_grad()
def estimate_loss(eval_steps):
  out = {}
  model_ffn.eval()
  for split in ["train", "validation"]:
    losses = torch.zeros(eval_steps)
    for k in range(eval_steps):
      x, y, a = get_batch(split)
      logits = model_ffn(x, attention_mask = a)
      loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), y.view(-1))
      losses[k] = loss.item()
    out[split] = losses.mean()
  model_ffn.train()
  return out

In [10]:
import torch
import torch.nn as nn

In [16]:
class MiniLM_FFN(nn.Module):
    def __init__(self, model, dim, n_labels):
        super().__init__()
        self.minilm = model
        self.n_emb = model.config.hidden_size # Should be 384

        
        for param in self.minilm.parameters():
            param.requires_grad = False

        self.lstm = nn.LSTM(self.n_emb, dim, num_layers=1, batch_first=True, bidirectional=True)
        self.proj = nn.Linear(dim * 2, n_labels)

    def forward(self, idxs, attention_mask):
        # Get MiniLM embeddings
        # outputs.last_hidden_state shape: (batch_size, seq_len, minilm_embedding_dim)
        x = self.minilm(input_ids=idxs, attention_mask=attention_mask)
        tok_embs = x.last_hidden_state

        # Apply FFN
        logits = self.proj(self.lstm(tok_embs)[0])
    
        return logits

In [17]:
torch.set_float32_matmul_precision("high")

In [23]:
model_ffn = MiniLM_FFN(model, 256, 4).to(device)
model_ffn = torch.compile(model_ffn)

In [24]:
from tqdm import tqdm

optim = torch.optim.AdamW(model_ffn.parameters(), lr = 1e-3)
max_steps = 1000

In [25]:
lossi = []
lri = []

for step in tqdm(range(max_steps)):
  # for g in optim.param_groups:
  #   g['lr'] = lrs[step]

  x, y, a = get_batch("train")
  optim.zero_grad()
  with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    logits = model_ffn(x, attention_mask = a)
    loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), y.view(-1))
  loss.backward()
  optim.step()
  # lri.append(lre[step])
  # lossi.append(loss.item())
  if step % 500 == 0:
    losses = estimate_loss(200)
    print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['validation']:.4f}")


  1%|          | 7/1000 [00:08<14:31,  1.14it/s]  

step 0: train loss 1.3146, val loss 1.3137


 51%|█████     | 511/1000 [00:25<02:29,  3.28it/s]

step 500: train loss 0.2287, val loss 0.3506


100%|██████████| 1000/1000 [00:34<00:00, 28.94it/s]


In [26]:
from seqeval.metrics import precision_score, recall_score, f1_score, classification_report

In [27]:
@torch.no_grad()
def evaluate_model(split="test"):
    """Evaluate model performance on given split with seqeval metrics"""
    model.eval()
    
    if split == "test":
        data_input_ids = test_data
        data_labels = test_labels
        data_attention_mask = test_attention_mask
    elif split == "validation":
        data_input_ids = val_data
        data_labels = val_labels
        data_attention_mask = val_attention_mask
    else:
        data_input_ids = train_data
        data_labels = train_labels
        data_attention_mask = train_attention_mask
    
    # Process in smaller batches to avoid OOM
    batch_size_eval = 16
    all_true_labels = []
    all_pred_labels = []
    
    # Process the entire dataset
    for i in tqdm(range(0, len(data_input_ids), batch_size_eval), desc=f"Evaluating on {split}"):
        # Get batch
        batch_input_ids = torch.tensor(data_input_ids[i:i+batch_size_eval]).to(device)
        batch_labels = torch.tensor(data_labels[i:i+batch_size_eval]).to(device)
        batch_attention_mask = torch.tensor(data_attention_mask[i:i+batch_size_eval]).to(device)
        
        # Get predictions
        logits = model_ffn(batch_input_ids, attention_mask=batch_attention_mask)
        predictions = torch.argmax(logits, dim=-1)
        
        # Convert predictions and labels to lists for seqeval
        for j in range(len(batch_input_ids)):
            true_label_ids = batch_labels[j].cpu().numpy()
            pred_label_ids = predictions[j].cpu().numpy()
            
            # Convert IDs to labels, handling special tokens
            true_seq = []
            pred_seq = []
            
            for true_id, pred_id, mask in zip(true_label_ids, pred_label_ids, batch_attention_mask[j]):
                if mask == 1 and true_id != -100:  # Only evaluate on non-padding and non-special tokens
                    true_seq.append(itol[true_id.item()])
                    pred_seq.append(itol[pred_id.item()])
            
            if true_seq:  # Only add if not empty
                all_true_labels.append(true_seq)
                all_pred_labels.append(pred_seq)
    
    # Calculate metrics using seqeval
    precision = precision_score(all_true_labels, all_pred_labels)
    recall = recall_score(all_true_labels, all_pred_labels)
    f1 = f1_score(all_true_labels, all_pred_labels)
    report = classification_report(all_true_labels, all_pred_labels)
    
    print(f"\n=== Evaluation on {split} split ===")
    print(f"F1 Score: {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print("\nDetailed Classification Report:")
    print(report)
    
    return {
        "f1": f1,
        "precision": precision,
        "recall": recall,
        "report": report
    }

In [28]:
evaluate_model()

Evaluating on test: 100%|██████████| 16/16 [00:04<00:00,  3.27it/s]



=== Evaluation on test split ===
F1 Score: 0.7381
Precision: 0.7253
Recall: 0.7514

Detailed Classification Report:
              precision    recall  f1-score   support

          AC       0.83      0.82      0.82       797
          LF       0.58      0.64      0.60       482

   micro avg       0.73      0.75      0.74      1279
   macro avg       0.70      0.73      0.71      1279
weighted avg       0.73      0.75      0.74      1279



{'f1': 0.7380952380952381,
 'precision': 0.7252830188679246,
 'recall': 0.7513682564503519,
 'report': '              precision    recall  f1-score   support\n\n          AC       0.83      0.82      0.82       797\n          LF       0.58      0.64      0.60       482\n\n   micro avg       0.73      0.75      0.74      1279\n   macro avg       0.70      0.73      0.71      1279\nweighted avg       0.73      0.75      0.74      1279\n'}

In [30]:
evaluate_model('validation')

Evaluating on validation: 100%|██████████| 10/10 [00:01<00:00,  7.35it/s]



=== Evaluation on validation split ===
F1 Score: 0.6911
Precision: 0.6765
Recall: 0.7064

Detailed Classification Report:
              precision    recall  f1-score   support

          AC       0.81      0.77      0.79       508
          LF       0.50      0.59      0.54       306

   micro avg       0.68      0.71      0.69       814
   macro avg       0.65      0.68      0.67       814
weighted avg       0.69      0.71      0.70       814



{'f1': 0.6911057692307693,
 'precision': 0.6764705882352942,
 'recall': 0.7063882063882064,
 'report': '              precision    recall  f1-score   support\n\n          AC       0.81      0.77      0.79       508\n          LF       0.50      0.59      0.54       306\n\n   micro avg       0.68      0.71      0.69       814\n   macro avg       0.65      0.68      0.67       814\nweighted avg       0.69      0.71      0.70       814\n'}