In [2]:
!pip install seqeval -q
!pip install -U transformers


  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone


In [5]:
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [11]:
# ===== Prototypical Networks: Token-level few-shot NER (BioBERT encoder) =====
import os, random, numpy as np
from pathlib import Path
from datasets import Dataset, DatasetDict
from transformers import (AutoTokenizer, AutoModel,
                          DataCollatorForTokenClassification, TrainingArguments, Trainer)
from seqeval.metrics import classification_report, f1_score

# ---- paths ----
BASE = Path("/content/drive/MyDrive/small_data_NER")
DATA_DIR = BASE / "conll/fewshot_k10_seed42_mention"   # <-- change to fewshot_k1_seed42 / k10 / k20 if needed
OUT_DIR  = BASE / "results"/"proto_net_baseline_k5_full"

# ---- read CoNLL ----
def read_conll(path):
    sents, tokens, labels = [], [], []
    with open(path, encoding="utf-8") as f:
        for line in f:
            line=line.strip()
            if not line:
                if tokens:
                    sents.append({"tokens":tokens, "ner_tags":labels})
                    tokens, labels = [], []
            else:
                parts = line.split()
                tok, lab = parts[0], parts[-1]
                tokens.append(tok); labels.append(lab)
    if tokens: sents.append({"tokens":tokens, "ner_tags":labels})
    return sents

train = read_conll(DATA_DIR/"train.conll")
dev   = read_conll(DATA_DIR/"dev.conll")
test  = read_conll(DATA_DIR/"test.conll")

print(f"Loaded: train={len(train)} dev={len(dev)} test={len(test)}")
print("Sample:", train[0]["tokens"][:12], "\n", train[0]["ner_tags"][:12])


Loaded: train=2 dev=200 test=851
Sample: ['He', 'had', 'a', 'medical', 'history', 'of', 'diabetes', 'mellitus', ',', 'hypertension', 'and', 'he'] 
 ['O', 'O', 'O', 'O', 'O', 'O', 'B-ety', 'I-ety', 'O', 'B-ety', 'O', 'O']


In [12]:
# ---- build label list (BIO) ----
all_labels = sorted({l for ex in (train+dev+test) for l in ex["ner_tags"]})
if "O" in all_labels:
    all_labels.remove("O"); all_labels = ["O"] + all_labels
label2id = {l:i for i,l in enumerate(all_labels)}
id2label = {i:l for l,i in label2id.items()}
num_labels = len(all_labels)
print("Labels:", all_labels)

# ---- HF datasets ----
ds = DatasetDict({
    "train": Dataset.from_list(train),
    "validation": Dataset.from_list(dev),
    "test": Dataset.from_list(test),
})

# ---- tokenizer & alignment ----
MODEL_NAME = "dmis-lab/biobert-base-cased-v1.1"  # encoder backbone
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize_align(batch):
    tokenized = tokenizer(batch["tokens"], is_split_into_words=True, truncation=True)
    labels = []
    for i, lbls in enumerate(batch["ner_tags"]):
        word_ids = tokenized.word_ids(batch_index=i)
        aligned = []
        prev_word = None
        for wid in word_ids:
            if wid is None:
                aligned.append(-100)
            else:
                # Only label the first wordpiece; rest -> -100
                if wid != prev_word:
                    aligned.append(label2id.get(lbls[wid], label2id["O"]))
                else:
                    aligned.append(-100)
                prev_word = wid
        labels.append(aligned)
    tokenized["labels"] = labels
    return tokenized

tokenized = ds.map(tokenize_align, batched=True, remove_columns=["tokens","ner_tags"])


Labels: ['O', 'B-ety', 'I-ety']


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/313 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/851 [00:00<?, ? examples/s]

In [13]:
print("Labels:", all_labels)
assert all_labels[0] == "O"
assert set(all_labels) >= {"B-ety","I-ety"}  # 若只有一种实体类型
print("✓ label vocab OK")


Labels: ['O', 'B-ety', 'I-ety']
✓ label vocab OK


In [14]:
def effective_labels_count(tokenized_split):
    return sum(int(x!=-100) for ex in tokenized_split["labels"] for x in ex)

print("eff labels (train):", effective_labels_count(tokenized["train"]))
print("eff labels (dev)  :", effective_labels_count(tokenized["validation"]))
print("eff labels (test) :", effective_labels_count(tokenized["test"]))


eff labels (train): 58
eff labels (dev)  : 3545
eff labels (test) : 16702


In [15]:
import numpy as np

def unique_effective_ids(tokenized_split):
    ids = []
    for ex in tokenized_split["labels"]:
        ids.extend([i for i in ex if i!=-100])
    return sorted(set(ids))

uids_train = unique_effective_ids(tokenized["train"])
print("unique label ids (train):", uids_train, [id2label[i] for i in uids_train])


unique label ids (train): [0, 1, 2] ['O', 'B-ety', 'I-ety']


In [17]:
# ---- class weights (reduce 'O') ----
from collections import Counter
import torch, numpy as np

cnt = Counter(l for ex in train for l in ex["ner_tags"])
weights = np.array([cnt.get(l,1) for l in all_labels], dtype=float)
weights = 1.0 / weights
weights /= weights.max()
weights[label2id['O']] *= 0.8
class_weights = torch.tensor(weights, dtype=torch.float)
print({l: float(class_weights[label2id[l]]) for l in all_labels})

# ---- metrics ----
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = logits.argmax(-1)
    pred_tags, true_tags = [], []
    for p, l in zip(preds, labels):
        pt, lt = [], []
        for pi, li in zip(p, l):
            if li == -100:  # skip subword positions
                continue
            pt.append(id2label[int(pi)])
            lt.append(id2label[int(li)])
        pred_tags.append(pt); true_tags.append(lt)
    f1 = f1_score(true_tags, pred_tags)
    return {"f1": f1}


{'O': 0.11428571492433548, 'B-ety': 0.6000000238418579, 'I-ety': 1.0}


In [None]:
# ---- Prototypical Token Classifier (Frozen encoder + trainable projection) ----
import torch
import torch.nn as nn
from transformers import AutoModel
from transformers.modeling_outputs import TokenClassifierOutput

class ProtoTokenClassifierFrozen(nn.Module):
    def __init__(self, model_name, num_labels, id2label=None, label2id=None, proj_dim=None, dropout=0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        # freeze encoder parameters
        for p in self.encoder.parameters():
            p.requires_grad = False
        hidden = self.encoder.config.hidden_size
        self.dropout = nn.Dropout(dropout)
        self.num_labels = num_labels
        self.id2label = id2label or {}
        self.label2id = label2id or {}
        # a small projection head; default proj_dim=hidden
        dim = proj_dim or hidden
        self.proj = nn.Sequential(nn.Linear(hidden, dim), nn.Tanh(), nn.LayerNorm(dim))
        # static prototypes for eval/inference (C, D)
        self.register_buffer("static_prototypes", None, persistent=False)

    def set_static_prototypes(self, protos):
        self.static_prototypes = protos

    def _features(self, **inputs):
        out = self.encoder(**inputs)
        x = self.dropout(out.last_hidden_state)
        x = self.proj(x)  # [B,T,D]
        return x

    @staticmethod
    def _pairwise_sq_dist(x, y):
        x2 = (x**2).sum(-1, keepdim=True)
        y2 = (y**2).sum(-1).unsqueeze(0)
        xy = x @ y.t()
        return x2 + y2 - 2*xy

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None, prototypes=None):
        feats = self._features(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        B, T, D = feats.shape
        device = feats.device
        protos = prototypes
        if protos is None:
            if labels is not None:
                feats_flat = feats.reshape(-1, D)
                labels_flat = labels.view(-1)
                protos_list = []
                for c in range(self.num_labels):
                    mask_c = (labels_flat == c)
                    if mask_c.any():
                        pc = feats_flat[mask_c].mean(dim=0)
                    else:
                        if self.static_prototypes is not None:
                            pc = self.static_prototypes[c]
                        else:
                            pc = torch.zeros(D, device=device)
                    protos_list.append(pc)
                protos = torch.stack(protos_list, dim=0)  # [C,D]
            elif self.static_prototypes is not None:
                protos = self.static_prototypes
        if protos is None:
            logits = torch.zeros(B, T, self.num_labels, device=device)
        else:
            x = feats.reshape(-1, D)
            dists = self._pairwise_sq_dist(x, protos)
            logits = (-dists).reshape(B, T, self.num_labels)
        return TokenClassifierOutput(logits=logits)

    @torch.no_grad()
    def compute_prototypes_from_dataset(self, dataloader, device, num_labels):
        sums = None
        counts = torch.zeros(num_labels, dtype=torch.long, device=device)
        for batch in dataloader:
            for k in list(batch.keys()):
                if isinstance(batch[k], torch.Tensor):
                    batch[k] = batch[k].to(device)
            labels = batch.get('labels', None)
            feats = self._features(input_ids=batch['input_ids'], attention_mask=batch.get('attention_mask'), token_type_ids=batch.get('token_type_ids'))
            B, T, D = feats.shape
            if sums is None:
                sums = torch.zeros(num_labels, D, device=device)
            feats_flat = feats.reshape(-1, D)
            if labels is None:
                continue
            labels_flat = labels.view(-1)
            for c in range(num_labels):
                mask_c = (labels_flat == c)
                if mask_c.any():
                    sums[c] += feats_flat[mask_c].sum(dim=0)
                    counts[c] += mask_c.sum()
        protos = torch.where(counts.view(-1,1) > 0, sums / counts.clamp(min=1).view(-1,1), torch.zeros_like(sums))
        return protos


In [19]:
# ---- training args ----
from transformers import TrainingArguments

OUT_DIR.mkdir(parents=True, exist_ok=True)

args = TrainingArguments(
    output_dir=str(OUT_DIR),
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=1e-5,
    num_train_epochs=50,
    weight_decay=0.01,
    logging_dir=str(OUT_DIR / "logs"),
    logging_steps=10,
    save_steps=500,
    seed=42,
    report_to=None
)
collator = DataCollatorForTokenClassification(tokenizer)


In [27]:
# ---- ProtoTrainer with weighted CE over -distance logits ----
from transformers import Trainer
import torch.nn as nn

class ProtoTrainer(Trainer):
    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights
        self._static_protos_ready = False

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop('labels')
        outputs = model(**inputs, labels=labels)  # in-batch prototypes
        logits = outputs.logits
        w = None
        if self.class_weights is not None:
            w = self.class_weights.to(logits.device)
        loss_fct = nn.CrossEntropyLoss(weight=w, ignore_index=-100, label_smoothing=0.1)
        loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

    def _ensure_static_prototypes(self):
        if self._static_protos_ready:
            return
        # build dataloader over train set
        dl = self.get_eval_dataloader(self.train_dataset)
        # Get device from model parameters
        device = next(self.model.parameters()).device
        protos = self.model.compute_prototypes_from_dataset(dl, device, num_labels)
        self.model.set_static_prototypes(protos)
        self._static_protos_ready = True

    def evaluate(self, *args, **kwargs):
        self._ensure_static_prototypes()
        return super().evaluate(*args, **kwargs)

    def predict(self, *args, **kwargs):
        self._ensure_static_prototypes()
        return super().predict(*args, **kwargs)

In [None]:
# ---- model + trainer ----
import os
os.environ["WANDB_MODE"] = "disabled"
model = ProtoTokenClassifierFrozen(MODEL_NAME, num_labels=num_labels, id2label=id2label, label2id=label2id, proj_dim=None)
trainer = ProtoTrainer(
    model=model, args=args,
    train_dataset=tokenized["train"], eval_dataset=tokenized["validation"],
    tokenizer=tokenizer, data_collator=collator, compute_metrics=compute_metrics,
    class_weights=class_weights,
)
trainer.train()


In [29]:
# dev label distribution diagnostics
out = trainer.predict(tokenized["validation"])
import numpy as np
from collections import Counter

pred_ids = np.argmax(out.predictions, axis=-1)
true_ids = out.label_ids

pred_tags = []
for p, l in zip(pred_ids, true_ids):
    pred_tags += [id2label[int(pi)] for pi, li in zip(p, l) if li != -100]
print("Pred label dist on DEV:", Counter(pred_tags))




Pred label dist on DEV: Counter({'O': 3095, 'B-ety': 228, 'I-ety': 222})


In [30]:
true_tags = []
for l in true_ids:
    true_tags += [id2label[int(li)] for li in l if li != -100]
print("Gold label dist on DEV:", Counter(true_tags))


Gold label dist on DEV: Counter({'O': 3285, 'B-ety': 134, 'I-ety': 126})


In [31]:
!nvidia-smi
import torch; print("cuda?", torch.cuda.is_available())


/bin/bash: line 1: nvidia-smi: command not found
cuda? False


In [32]:
# ---- evaluate (dev + test) ----
import time
def eval_split(name):
    out = trainer.evaluate(tokenized[name])
    print(f"{name.upper()} F1:", round(out["eval_f1"], 4))
    return out["eval_f1"]

f1_dev  = eval_split("validation")
f1_test = eval_split("test")

# ---- save predictions + detailed report on test ----
t0 = time.time()
pred = trainer.predict(tokenized["test"])
inference_time = time.time() - t0
pred_logits = pred.predictions
pred_ids = pred_logits.argmax(-1)
pred_tags, true_tags = [], []
for p, l in zip(pred_ids, tokenized["test"]["labels"]):
    pt, lt = [], []
    for pi, li in zip(p, l):
        if li == -100:
            continue
        pt.append(id2label[int(pi)])
        lt.append(id2label[int(li)])
    pred_tags.append(pt); true_tags.append(lt)

print("\nClassification report (test):")
print(classification_report(true_tags, pred_tags))
print(f"Inference Time (s): {inference_time:.4f}")

# save minimal metrics
import json
with open(OUT_DIR/"metrics.json","w") as f:
    json.dump({"f1_dev": float(f1_dev), "f1_test": float(f1_test), "inference_time_s": float(inference_time)}, f, indent=2)
print(f"\nSaved metrics to {OUT_DIR}/metrics.json")


VALIDATION F1: 0.5097
TEST F1: 0.559

Classification report (test):
              precision    recall  f1-score   support

         ety       0.41      0.88      0.56       516

   micro avg       0.41      0.88      0.56       516
   macro avg       0.41      0.88      0.56       516
weighted avg       0.41      0.88      0.56       516


Saved metrics to /content/drive/MyDrive/small_data_NER/results/proto_net_k5_full/metrics.json
