In [1]:
import torch
import torch.nn.functional as F

from datasets import load_dataset

In [2]:
dataset = load_dataset("surrey-nlp/PLOD-CW-25")

In [3]:
labels = ["O", "B-AC", "B-LF", "I-LF"]
n_labels = len(labels)
ltoi = {l: i for i, l in enumerate(labels)}
itol = {i: l for l, i in ltoi.items()}

In [18]:
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

#Mean Pooling - Take attention mask into account for correct averaging
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


# Load model from HuggingFace Hub
tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')
model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2')

In [19]:
def compute_embeds(examples):
    encoded_input = tokenizer(examples['tokens'], padding=True, truncation=True, return_tensors='pt', is_split_into_words=True)

    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input)

    # Perform pooling
    sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])

    # Normalize embeddings
    sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
    return sentence_embeddings

In [48]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True, # Crucial for pre-tokenized input
        max_length=512,
        padding="max_length"

    )

    labels = []
    for i, label_sequence in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None: # Special tokens ([CLS], [SEP])
                label_ids.append(-100)
            elif word_idx != previous_word_idx: # First token of a new word
                label_ids.append(ltoi[label_sequence[word_idx]])
            else: # Subsequent tokens of the same word
                label_ids.append(-100)


            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [49]:
data = dataset.map(tokenize_and_align_labels, batched = True)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/250 [00:00<?, ? examples/s]

Map:   0%|          | 0/150 [00:00<?, ? examples/s]

In [50]:
train_data, train_labels, train_attention_mask = data['train']['input_ids'], data['train']['labels'], data['train']['attention_mask']
val_data, val_labels, val_attention_mask = data['validation']['input_ids'], data['validation']['labels'], data['validation']['attention_mask']
test_data, test_labels, test_attention_mask = data['test']['input_ids'], data['test']['labels'], data['test']['attention_mask']

In [76]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 8

def get_batch(split = "train"):
  data = train_data if split == "train" else val_data
  labels = train_labels if split == "train" else val_labels
  attention_mask = train_attention_mask if split == "train" else val_attention_mask
  ix = torch.randint(len(data), (batch_size,))
  x = torch.stack([torch.tensor(data[i]).long() for i in ix])
  y = torch.stack([torch.tensor(labels[i]).long() for i in ix])
  a = torch.stack([torch.tensor(attention_mask[i]) for i in ix])
  return x.to(device), y.to(device), a.to(device)


@torch.no_grad()
def estimate_loss(eval_steps):
  out = {}
  model_ffn.eval()
  for split in ["train", "validation"]:
    losses = torch.zeros(eval_steps)
    for k in range(eval_steps):
      x, y, a = get_batch(split)
      logits = model_ffn(x, attention_mask = a)
      loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), y.view(-1))
      losses[k] = loss.item()
    out[split] = losses.mean()
  model_ffn.train()
  return out

In [77]:
import torch
import torch.nn as nn

In [78]:
class MiniLM_FFN(nn.Module):
    def __init__(self, model, dim, n_labels):
        super().__init__()
        self.minilm = model
        self.n_emb = model.config.hidden_size # Should be 384

        
        for param in self.minilm.parameters():
            param.requires_grad = False

        self.fc1 = nn.Linear(self.n_emb, dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.fc2 = nn.Linear(dim, n_labels)

    def forward(self, idxs, attention_mask):
        # Get MiniLM embeddings
        # outputs.last_hidden_state shape: (batch_size, seq_len, minilm_embedding_dim)
        x = self.minilm(input_ids=idxs, attention_mask=attention_mask)
        tok_embs = x.last_hidden_state

        # Apply FFN
        x = self.fc1(tok_embs)
        x = self.relu(x)
        x = self.dropout(x)
        logits = self.fc2(x) # (batch_size, seq_len, num_labels)
        return logits

In [79]:
torch.set_float32_matmul_precision("high")

In [95]:
model_ffn = MiniLM_FFN(model, 256, 4).to(device)
model_ffn = torch.compile(model_ffn)

In [96]:
from tqdm import tqdm

optim = torch.optim.AdamW(model_ffn.parameters(), lr = 1e-3)
max_steps = 5000

In [97]:
lossi = []
lri = []

for step in tqdm(range(max_steps)):
  # for g in optim.param_groups:
  #   g['lr'] = lrs[step]

  x, y, a = get_batch("train")
  optim.zero_grad()
  with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    logits = model_ffn(x, attention_mask = a)
    loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), y.view(-1))
  loss.backward()
  optim.step()
  # lri.append(lre[step])
  # lossi.append(loss.item())
  if step % 500 == 0:
    losses = estimate_loss(200)
    print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['validation']:.4f}")


  0%|          | 16/5000 [00:05<22:22,  3.71it/s] 

step 0: train loss 1.2960, val loss 1.2971


 11%|█         | 526/5000 [00:15<06:41, 11.13it/s] 

step 500: train loss 0.4412, val loss 0.5021


 20%|██        | 1021/5000 [00:24<05:58, 11.10it/s]

step 1000: train loss 0.3733, val loss 0.4601


 31%|███       | 1529/5000 [00:34<05:13, 11.07it/s] 

step 1500: train loss 0.3567, val loss 0.4624


 40%|████      | 2020/5000 [00:43<04:29, 11.04it/s] 

step 2000: train loss 0.3456, val loss 0.4381


 50%|█████     | 2521/5000 [00:53<03:49, 10.79it/s] 

step 2500: train loss 0.3166, val loss 0.4514


 60%|██████    | 3018/5000 [01:03<03:09, 10.45it/s] 

step 3000: train loss 0.3016, val loss 0.4630


 70%|███████   | 3518/5000 [01:12<02:20, 10.58it/s] 

step 3500: train loss 0.2928, val loss 0.4455


 80%|████████  | 4017/5000 [01:22<01:32, 10.59it/s] 

step 4000: train loss 0.2831, val loss 0.4476


 90%|█████████ | 4524/5000 [01:31<00:46, 10.34it/s] 

step 4500: train loss 0.2718, val loss 0.4618


100%|██████████| 5000/5000 [01:35<00:00, 52.54it/s] 


In [103]:
from seqeval.metrics import precision_score, recall_score, f1_score, classification_report

In [104]:
@torch.no_grad()
def evaluate_model(split="test"):
    """Evaluate model performance on given split with seqeval metrics"""
    model.eval()
    
    if split == "test":
        data_input_ids = test_data
        data_labels = test_labels
        data_attention_mask = test_attention_mask
    elif split == "validation":
        data_input_ids = val_data
        data_labels = val_labels
        data_attention_mask = val_attention_mask
    else:
        data_input_ids = train_data
        data_labels = train_labels
        data_attention_mask = train_attention_mask
    
    # Process in smaller batches to avoid OOM
    batch_size_eval = 16
    all_true_labels = []
    all_pred_labels = []
    
    # Process the entire dataset
    for i in tqdm(range(0, len(data_input_ids), batch_size_eval), desc=f"Evaluating on {split}"):
        # Get batch
        batch_input_ids = torch.tensor(data_input_ids[i:i+batch_size_eval]).to(device)
        batch_labels = torch.tensor(data_labels[i:i+batch_size_eval]).to(device)
        batch_attention_mask = torch.tensor(data_attention_mask[i:i+batch_size_eval]).to(device)
        
        # Get predictions
        logits = model_ffn(batch_input_ids, attention_mask=batch_attention_mask)
        predictions = torch.argmax(logits, dim=-1)
        
        # Convert predictions and labels to lists for seqeval
        for j in range(len(batch_input_ids)):
            true_label_ids = batch_labels[j].cpu().numpy()
            pred_label_ids = predictions[j].cpu().numpy()
            
            # Convert IDs to labels, handling special tokens
            true_seq = []
            pred_seq = []
            
            for true_id, pred_id, mask in zip(true_label_ids, pred_label_ids, batch_attention_mask[j]):
                if mask == 1 and true_id != -100:  # Only evaluate on non-padding and non-special tokens
                    true_seq.append(itol[true_id.item()])
                    pred_seq.append(itol[pred_id.item()])
            
            if true_seq:  # Only add if not empty
                all_true_labels.append(true_seq)
                all_pred_labels.append(pred_seq)
    
    # Calculate metrics using seqeval
    precision = precision_score(all_true_labels, all_pred_labels)
    recall = recall_score(all_true_labels, all_pred_labels)
    f1 = f1_score(all_true_labels, all_pred_labels)
    report = classification_report(all_true_labels, all_pred_labels)
    
    print(f"\n=== Evaluation on {split} split ===")
    print(f"F1 Score: {f1:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print("\nDetailed Classification Report:")
    print(report)
    
    return {
        "f1": f1,
        "precision": precision,
        "recall": recall,
        "report": report
    }

In [None]:
evaluate_model()

Evaluating on test: 100%|██████████| 16/16 [00:01<00:00,  9.33it/s]



=== Evaluation on test split ===
F1 Score: 0.4796
Precision: 0.4169
Recall: 0.5645

Detailed Classification Report:
              precision    recall  f1-score   support

          AC       0.75      0.75      0.75       797
          LF       0.13      0.26      0.17       482

   micro avg       0.42      0.56      0.48      1279
   macro avg       0.44      0.50      0.46      1279
weighted avg       0.52      0.56      0.53      1279

{'f1': 0.4795748920624377, 'precision': 0.4168591224018476, 'recall': 0.5645035183737295, 'report': '              precision    recall  f1-score   support\n\n          AC       0.75      0.75      0.75       797\n          LF       0.13      0.26      0.17       482\n\n   micro avg       0.42      0.56      0.48      1279\n   macro avg       0.44      0.50      0.46      1279\nweighted avg       0.52      0.56      0.53      1279\n'}


In [107]:
evaluate_model('validation')

Evaluating on validation: 100%|██████████| 10/10 [00:01<00:00,  9.77it/s]



=== Evaluation on validation split ===
F1 Score: 0.4783
Precision: 0.4160
Recall: 0.5627

Detailed Classification Report:
              precision    recall  f1-score   support

          AC       0.75      0.75      0.75       508
          LF       0.13      0.25      0.17       306

   micro avg       0.42      0.56      0.48       814
   macro avg       0.44      0.50      0.46       814
weighted avg       0.51      0.56      0.53       814



{'f1': 0.4783289817232376,
 'precision': 0.4159854677565849,
 'recall': 0.5626535626535627,
 'report': '              precision    recall  f1-score   support\n\n          AC       0.75      0.75      0.75       508\n          LF       0.13      0.25      0.17       306\n\n   micro avg       0.42      0.56      0.48       814\n   macro avg       0.44      0.50      0.46       814\nweighted avg       0.51      0.56      0.53       814\n'}

In [42]:
def labels(examples):
    labels = []
    for i, label_sequence in enumerate(examples["ner_tags"]):
        encoded = [ltoi[k] for k in label_sequence]
        labels.append(torch.tensor(encoded).long())
    encoded = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
    return encoded

In [45]:
train_data = compute_embeds(dataset['train'])
val_data = compute_embeds(dataset['validation'])
test_data = compute_embeds(dataset['test'])

train_labels = labels(dataset['train'])
val_labels = labels(dataset['validation'])
test_labels = labels(dataset['test'])

In [46]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size = 8

def get_batch(split = "train"):
  data = train_data if split == "train" else val_data
  labels = train_labels if split =='train' else val_labels
  ix = torch.randint(len(data), (batch_size,))
  x = torch.stack([torch.tensor(data[i]).long() for i in ix])
  y = torch.stack([torch.tensor(labels[i]).long() for i in ix])
  return x.to(device), y.to(device), 


@torch.no_grad()
def estimate_loss(eval_steps):
  out = {}
  model.eval()
  for split in ["train", "validation"]:
    losses = torch.zeros(eval_steps)
    for k in range(eval_steps):
      x, y = get_batch(split)
      logits, loss = model(x, y)
      losses[k] = loss.item()
    out[split] = losses.mean()
  model.train()
  return out