# [Fully Frozen] Prototypical Networks (Token-level Few-shot NER)

This notebook implements a fully frozen variant of prototypical networks for token-level NER.
- Encoder (e.g., BioBERT) is fully frozen.
- No trainable projection head (Identity).
- Prototypes are computed from the train split and used for evaluation on dev/test.

Note: This is separate from the [Projection-Head Fine-tune] baseline in `prototypical_networks_baseline.ipynb`.

In [1]:
!pip install seqeval -q
!pip install -U transformers


[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m1.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone


In [2]:
# Optional: Google Drive (Colab)
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [23]:
# [Fully Frozen] Setup paths, load data
import os, random, numpy as np
from pathlib import Path
from datasets import Dataset, DatasetDict
from transformers import (AutoTokenizer, DataCollatorForTokenClassification, TrainingArguments)
from seqeval.metrics import classification_report, f1_score

BASE = Path('/content/drive/MyDrive/small_data_NER')  # adjust if not on Colab/Drive
DATA_DIR = BASE / 'conll/fewshot_k20_seed42_mention'  # change to k1/k10/k20 variants as needed
OUT_DIR  = BASE / 'results' / 'proto_net_full_frozen_k5_full'
OUT_DIR.mkdir(parents=True, exist_ok=True)

def read_conll(path):
    sents, tokens, labels = [], [], []
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                if tokens:
                    sents.append({'tokens': tokens, 'ner_tags': labels})
                    tokens, labels = [], []
            else:
                parts = line.split()
                tok, lab = parts[0], parts[-1]
                tokens.append(tok); labels.append(lab)
    if tokens: sents.append({'tokens': tokens, 'ner_tags': labels})
    return sents

train = read_conll(DATA_DIR/'train.conll')
dev   = read_conll(DATA_DIR/'dev.conll')
test  = read_conll(DATA_DIR/'test.conll')

print(f'Loaded: train={len(train)} dev={len(dev)} test={len(test)}')
print('Sample:', train[0]['tokens'][:12], '\n', train[0]['ner_tags'][:12])


Loaded: train=4 dev=200 test=851
Sample: ['He', 'had', 'a', 'medical', 'history', 'of', 'diabetes', 'mellitus', ',', 'hypertension', 'and', 'he'] 
 ['O', 'O', 'O', 'O', 'O', 'O', 'B-ety', 'I-ety', 'O', 'B-ety', 'O', 'O']


In [24]:
# [Fully Frozen] Labels, datasets, tokenizer + alignment
all_labels = sorted({l for ex in (train+dev+test) for l in ex['ner_tags']})
if 'O' in all_labels:
    all_labels.remove('O'); all_labels = ['O'] + all_labels
label2id = {l:i for i,l in enumerate(all_labels)}
id2label = {i:l for l,i in label2id.items()}
num_labels = len(all_labels)
print('Labels:', all_labels)

ds = DatasetDict({
    'train': Dataset.from_list(train),
    'validation': Dataset.from_list(dev),
    'test': Dataset.from_list(test),
})

MODEL_NAME = 'dmis-lab/biobert-base-cased-v1.1'
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize_align(batch):
    tokenized = tokenizer(batch['tokens'], is_split_into_words=True, truncation=True)
    labels = []
    for i, lbls in enumerate(batch['ner_tags']):
        word_ids = tokenized.word_ids(batch_index=i)
        aligned = []
        prev_word = None
        for wid in word_ids:
            if wid is None:
                aligned.append(-100)
            else:
                if wid != prev_word:
                    aligned.append(label2id.get(lbls[wid], label2id['O']))
                else:
                    aligned.append(-100)
                prev_word = wid
        labels.append(aligned)
    tokenized['labels'] = labels
    return tokenized

tokenized = ds.map(tokenize_align, batched=True, remove_columns=['tokens','ner_tags'])


Labels: ['O', 'B-ety', 'I-ety']


Map:   0%|          | 0/4 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/851 [00:00<?, ? examples/s]

In [25]:
# [Fully Frozen] Metrics (seqeval F1)
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logits.argmax(-1)
    pred_tags, true_tags = [], []
    for p, l in zip(preds, labels):
        pt, lt = [], []
        for pi, li in zip(p, l):
            if li == -100:
                continue
            pt.append(id2label[int(pi)])
            lt.append(id2label[int(li)])
        pred_tags.append(pt); true_tags.append(lt)
    return {'f1': f1_score(true_tags, pred_tags)}


In [26]:
# [Fully Frozen] Model definition (encoder frozen, Identity projection)
import torch
import torch.nn as nn
from transformers import AutoModel
from transformers.modeling_outputs import TokenClassifierOutput

class ProtoTokenClassifierFullyFrozen(nn.Module):
    def __init__(self, model_name, num_labels, id2label=None, label2id=None, dropout=0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        for p in self.encoder.parameters():
            p.requires_grad = False
        hidden = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.num_labels = num_labels
        self.id2label = id2label or {}
        self.label2id = label2id or {}
        self.proj = nn.Identity()
        self.register_buffer('static_prototypes', None, persistent=False)

    def set_static_prototypes(self, protos):
        self.static_prototypes = protos

    def _features(self, **inputs):
        out = self.encoder(**inputs)
        x = self.dropout(out.last_hidden_state)
        x = self.proj(x)
        return x

    @staticmethod
    def _pairwise_sq_dist(x, y):
        x2 = (x**2).sum(-1, keepdim=True)
        y2 = (y**2).sum(-1).unsqueeze(0)
        xy = x @ y.t()
        return x2 + y2 - 2*xy

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, prototypes=None):
        feats = self._features(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        B, T, H = feats.shape
        device = feats.device
        protos = prototypes
        if protos is None:
            if labels is not None:
                feats_flat = feats.reshape(-1, H)
                labels_flat = labels.view(-1)
                protos_list = []
                for c in range(self.num_labels):
                    mask_c = (labels_flat == c)
                    if mask_c.any():
                        pc = feats_flat[mask_c].mean(dim=0)
                    else:
                        if self.static_prototypes is not None:
                            pc = self.static_prototypes[c]
                        else:
                            pc = torch.zeros(H, device=device)
                    protos_list.append(pc)
                protos = torch.stack(protos_list, dim=0)
            elif self.static_prototypes is not None:
                protos = self.static_prototypes
        if protos is None:
            logits = torch.zeros(B, T, self.num_labels, device=device)
        else:
            x = feats.reshape(-1, H)
            dists = self._pairwise_sq_dist(x, protos)
            logits = (-dists).reshape(B, T, self.num_labels)
        return TokenClassifierOutput(logits=logits)

    @torch.no_grad()
    def compute_prototypes_from_dataset(self, dataloader, device, num_labels):
        sums = None
        counts = torch.zeros(num_labels, dtype=torch.long, device=device)
        for batch in dataloader:
            for k in list(batch.keys()):
                if isinstance(batch[k], torch.Tensor):
                    batch[k] = batch[k].to(device)
            labels = batch.get('labels', None)
            feats = self._features(input_ids=batch['input_ids'], attention_mask=batch.get('attention_mask'), token_type_ids=batch.get('token_type_ids'))
            B, T, H = feats.shape
            if sums is None:
                sums = torch.zeros(num_labels, H, device=device)
            feats_flat = feats.reshape(-1, H)
            if labels is None:
                continue
            labels_flat = labels.view(-1)
            for c in range(num_labels):
                mask_c = (labels_flat == c)
                if mask_c.any():
                    sums[c] += feats_flat[mask_c].sum(dim=0)
                    counts[c] += mask_c.sum()
        protos = torch.where(counts.view(-1,1) > 0, sums / counts.clamp(min=1).view(-1,1), torch.zeros_like(sums))
        return protos


In [27]:
# [Fully Frozen] ProtoTrainer (build static prototypes before eval/predict)
from transformers import Trainer
import torch.nn as nn

class ProtoTrainer(Trainer):
    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights
        self._static_protos_ready = False

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop('labels')
        outputs = model(**inputs, labels=labels)
        logits = outputs.logits
        loss_fct = nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=0.0)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

    def _ensure_static_prototypes(self):
        if self._static_protos_ready:
            return
        dl = self.get_eval_dataloader(self.train_dataset)
        device = next(self.model.parameters()).device
        protos = self.model.compute_prototypes_from_dataset(dl, device, self.model.num_labels)
        self.model.set_static_prototypes(protos)
        self._static_protos_ready = True

    def evaluate(self, *args, **kwargs):
        self._ensure_static_prototypes()
        return super().evaluate(*args, **kwargs)

    def predict(self, *args, **kwargs):
        self._ensure_static_prototypes()
        return super().predict(*args, **kwargs)


In [28]:
# [Fully Frozen] Arguments, collator, model+trainer, and evaluation (no training)
import os, json
import re
os.environ['WANDB_MODE'] = 'disabled'

collator = DataCollatorForTokenClassification(tokenizer)

args = TrainingArguments(
    output_dir=str(OUT_DIR),
    do_train=False,
    do_eval=True,
    per_device_eval_batch_size=8,
    seed=42,
    report_to=None,
)

model = ProtoTokenClassifierFullyFrozen(MODEL_NAME, num_labels=num_labels, id2label=id2label, label2id=label2id)

trainer = ProtoTrainer(
    model=model,
    args=args,
    train_dataset=tokenized['train'],  # used to build static prototypes
    eval_dataset=tokenized['validation'],
    tokenizer=tokenizer,
    data_collator=collator,
    compute_metrics=compute_metrics,
)

dev_metrics  = trainer.evaluate(tokenized['validation'])
test_metrics = trainer.evaluate(tokenized['test'])

pred = trainer.predict(tokenized['test'])
pred_ids = pred.predictions.argmax(-1)
pred_tags, true_tags = [], []
for p, l in zip(pred_ids, tokenized['test']['labels']):
    pt, lt = [], []
    for pi, li in zip(p, l):
        if li == -100:
            continue
        pt.append(id2label[int(pi)])
        lt.append(id2label[int(li)])
    pred_tags.append(pt); true_tags.append(lt)

rep = classification_report(true_tags, pred_tags)

# Extract precision, recall from classification report for 'weighted avg'
precision_test, recall_test = 0.0, 0.0
lines = rep.split('\n')
for line in lines:
    if 'weighted avg' in line:
        match = re.search(r'weighted avg\s+(\d+\.\d+)\s+(\d+\.\d+)', line)
        if match:
            precision_test = float(match.group(1))
            recall_test = float(match.group(2))
        break

# Output the requested metrics with 4 decimal places
print(f'Validation F1: {dev_metrics["eval_f1"]:.4f}')
print(f'Test Precision: {precision_test:.4f}')
print(f'Test Recall: {recall_test:.4f}')
print(f'Test F1-score: {test_metrics["eval_f1"]:.4f}')
print(f'Test Inference Time (s): {test_metrics["eval_runtime"]:.4f}')

# Update metrics.json to include all requested metrics
with open(OUT_DIR/'metrics.json','w') as f:
    json.dump({
        'f1_dev': float(dev_metrics['eval_f1']),
        'f1_test': float(test_metrics['eval_f1']),
        'precision_test': precision_test,
        'recall_test': recall_test,
        'inference_time_test': float(test_metrics['eval_runtime'])
    }, f, indent=4)
print(f'Saved metrics to {OUT_DIR}/metrics.json')

with open(OUT_DIR/'classification_report_test.txt','w') as f:
    f.write(rep)

trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Trainable params (FULL-FROZEN):', trainable_params)


  super().__init__(*args, **kwargs)


Validation F1: 0.6324
Test Precision: 0.4700
Test Recall: 0.9100
Test F1-score: 0.6190
Test Inference Time (s): 142.0920
Saved metrics to /content/drive/MyDrive/small_data_NER/results/proto_net_full_frozen_k5_full/metrics.json
Trainable params (FULL-FROZEN): 0
