In [None]:
!pip install seqeval -q
!pip install -U transformers

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
Collecting transformers
  Downloading transformers-4.57.0-py3-none-any.whl.metadata (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.4/41.4 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Downloading transformers-4.57.0-py3-none-any.whl (12.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.0/12.0 MB[0m [31m71.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: transformers
  Attempting uninstall: transformers
    Found existing installation: transformers 4.56.2
    Uninstalling transformers-4.56.2:
      Successfully uninstalled transformers-4.56.2
Successfully installed tran

In [None]:
# Mount Google Drive at /content/drive (Colab only); the project paths used
# below live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [42]:
# ===== Baseline: BioBERT fine-tuning on E3C few-shot =====
import os, random, numpy as np
from pathlib import Path
from datasets import Dataset, DatasetDict
from transformers import (AutoTokenizer, AutoModelForTokenClassification,
                          DataCollatorForTokenClassification, TrainingArguments, Trainer)
from seqeval.metrics import classification_report, f1_score

# ---- paths ----
# The few-shot setting is defined once and used for BOTH the input folder and
# the results folder, so metrics can never be written under the label of a
# different k than the data they were computed on.
K_SHOT = 10         # <-- change to 1 / 5 / 20 to run another few-shot setting
FEWSHOT_SEED = 42   # seed component of the few-shot folder name
BASE = Path("/content/drive/MyDrive/small_data_NER_project")
DATA_DIR = BASE / "conll" / f"fewshot_k{K_SHOT}_seed{FEWSHOT_SEED}_mention"
OUT_DIR = BASE / "results" / f"biobert_k{K_SHOT}_full"

# ---- read CoNLL ----
def read_conll(path):
    """Read a whitespace-separated CoNLL file into a list of sentences.

    Each non-blank line must hold at least two columns; the first column is
    taken as the token and the last column as its tag. Blank lines and
    ``-DOCSTART-`` marker lines end the current sentence.

    Args:
        path: Location of the ``.conll`` file (anything ``open`` accepts).

    Returns:
        A list of ``{"tokens": [...], "ner_tags": [...]}`` dicts, one per
        non-empty sentence, in file order.

    Raises:
        ValueError: If a non-blank line has fewer than two columns. Without
            this check the token itself would silently be used as its tag.
    """
    sents, tokens, labels = [], [], []
    with open(path, encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            line = raw.strip()
            # Blank lines and document markers both act as sentence boundaries.
            if not line or line.startswith("-DOCSTART-"):
                if tokens:
                    sents.append({"tokens": tokens, "ner_tags": labels})
                    tokens, labels = [], []
                continue
            parts = line.split()
            if len(parts) < 2:
                raise ValueError(
                    f"{path}:{lineno}: expected 'token ... label', got {line!r}"
                )
            tokens.append(parts[0])
            labels.append(parts[-1])
    # The file may not end with a blank line; keep the trailing sentence.
    if tokens:
        sents.append({"tokens": tokens, "ner_tags": labels})
    return sents

# Load the three splits from DATA_DIR (file names are <split>.conll).
train, dev, test = (
    read_conll(DATA_DIR / f"{split_name}.conll")
    for split_name in ("train", "dev", "test")
)

# Quick sanity check: split sizes plus the first 12 tokens/tags of sentence 0.
print(f"Loaded: train={len(train)} dev={len(dev)} test={len(test)}")
first_sentence = train[0]
print("Sample:", first_sentence["tokens"][:12], "\n", first_sentence["ner_tags"][:12])

Loaded: train=2 dev=200 test=851
Sample: ['He', 'had', 'a', 'medical', 'history', 'of', 'diabetes', 'mellitus', ',', 'hypertension', 'and', 'he'] 
 ['O', 'O', 'O', 'O', 'O', 'O', 'B-ety', 'I-ety', 'O', 'B-ety', 'O', 'O']


In [43]:
# ---- build label list (BIO) ----
# Collect every tag that occurs in any split. "O" is always placed at index 0,
# even if no split happens to contain it, because other cells look up
# label2id["O"] directly and would otherwise fail with a KeyError.
seen_labels = {l for ex in (train + dev + test) for l in ex["ner_tags"]}
all_labels = ["O"] + sorted(seen_labels - {"O"})
label2id = {l: i for i, l in enumerate(all_labels)}
id2label = {i: l for l, i in label2id.items()}
num_labels = len(all_labels)
print("Labels:", all_labels)

# ---- HF datasets ----
# Wrap the three sentence lists as Hugging Face datasets; the dev split is
# exposed under the key "validation".
ds = DatasetDict({
    split_key: Dataset.from_list(rows)
    for split_key, rows in (("train", train), ("validation", dev), ("test", test))
})

# ---- tokenizer & alignment ----
# Tokenizer is loaded from the same checkpoint name that the model uses.
MODEL_NAME = "dmis-lab/biobert-base-cased-v1.1"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize_align(batch, max_length=512):
    """Tokenize pre-split sentences and align word-level tags to wordpieces.

    Only the first wordpiece of each word receives the word's label id; every
    other position (special tokens and continuation pieces) is set to -100.
    Tags that are missing from ``label2id`` fall back to the id of "O".

    Args:
        batch: Mapping with ``"tokens"`` (list of word lists) and
            ``"ner_tags"`` (list of tag-string lists), as passed by
            ``Dataset.map(..., batched=True)``.
        max_length: Maximum number of wordpieces per sequence. It is passed
            explicitly because ``truncation=True`` alone has no effect when
            the tokenizer carries no predefined maximum length (the tokenizer
            then warns and skips truncation). 512 is the BERT-base position
            limit.

    Returns:
        The tokenizer output with an added ``"labels"`` entry holding one list
        of label ids per sentence.
    """
    tokenized = tokenizer(
        batch["tokens"],
        is_split_into_words=True,
        truncation=True,
        max_length=max_length,
    )
    fallback_id = label2id["O"]
    labels = []
    for i, lbls in enumerate(batch["ner_tags"]):
        aligned = []
        prev_word = None
        for wid in tokenized.word_ids(batch_index=i):
            if wid is None:
                # Position has no source word (e.g. special tokens).
                aligned.append(-100)
                continue
            # Label the first wordpiece of a word; mask its continuations.
            aligned.append(label2id.get(lbls[wid], fallback_id) if wid != prev_word else -100)
            prev_word = wid
        labels.append(aligned)
    tokenized["labels"] = labels
    return tokenized

# Run tokenize_align over every split in batches; the raw "tokens" and
# "ner_tags" columns are removed from the mapped datasets.
tokenized = ds.map(tokenize_align, batched=True, remove_columns=["tokens","ner_tags"])


Labels: ['O', 'B-ety', 'I-ety']


Map:   0%|          | 0/2 [00:00<?, ? examples/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


Map:   0%|          | 0/200 [00:00<?, ? examples/s]

Map:   0%|          | 0/851 [00:00<?, ? examples/s]

In [44]:
# Sanity-check the label vocabulary before training.
print("Labels:", all_labels)
assert all_labels[0] == "O"
# For the case where there is only a single entity type ("ety").
assert {"B-ety", "I-ety"}.issubset(all_labels)
print("✓ label vocab OK")

Labels: ['O', 'B-ety', 'I-ety']
✓ label vocab OK


In [45]:
def effective_labels_count(tokenized_split):
    """Count the label positions in a split that are not masked with -100.

    Args:
        tokenized_split: Object whose ``["labels"]`` entry is an iterable of
            per-sentence label-id sequences.

    Returns:
        Total number of label ids different from -100.
    """
    total = 0
    for label_seq in tokenized_split["labels"]:
        total += sum(1 for label_id in label_seq if label_id != -100)
    return total

# Report how many positions per split actually carry a label.
for header, split_key in (
    ("eff labels (train):", "train"),
    ("eff labels (dev)  :", "validation"),
    ("eff labels (test) :", "test"),
):
    print(header, effective_labels_count(tokenized[split_key]))


eff labels (train): 58
eff labels (dev)  : 3545
eff labels (test) : 16702


In [46]:
import numpy as np

def unique_effective_ids(tokenized_split):
    """Return the sorted distinct label ids of a split, ignoring -100.

    Args:
        tokenized_split: Object whose ``["labels"]`` entry is an iterable of
            per-sentence label-id sequences.

    Returns:
        Sorted list of the distinct ids that are not -100.
    """
    return sorted({
        label_id
        for label_seq in tokenized_split["labels"]
        for label_id in label_seq
        if label_id != -100
    })

# Which label ids actually occur (unmasked) in the training split?
uids_train = unique_effective_ids(tokenized["train"])
uids_train_names = list(map(id2label.__getitem__, uids_train))
print("unique label ids (train):", uids_train, uids_train_names)

unique label ids (train): [0, 1, 2] ['O', 'B-ety', 'I-ety']


In [88]:
# ---- model ----
# Token-classification head on top of the pretrained checkpoint, with the
# label maps stored in the model config.
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_NAME,
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
)

# ---- inverse-frequency class weights, with "O" down-weighted further ----
from collections import Counter
import torch, numpy as np

# Word-level tag frequencies in the training sentences.
cnt = Counter()
for _sentence in train:
    cnt.update(_sentence["ner_tags"])

# Tags never seen in train count as 1 so the inverse stays finite.
weights = 1.0 / np.array([cnt.get(tag, 1) for tag in all_labels], dtype=float)
weights = weights / weights.max()   # rarest tag -> weight 1.0
weights[label2id['O']] *= 0.8       # extra damping for the "O" tag
class_weights = torch.tensor(weights, dtype=torch.float)
print({tag: float(class_weights[label2id[tag]]) for tag in all_labels})

# ---- metrics ----
def compute_metrics(eval_pred):
    """Compute the seqeval F1 score for the Trainer.

    Args:
        eval_pred: Pair ``(logits, labels)``; the predicted class is the
            argmax over the last axis of ``logits``.

    Returns:
        ``{"f1": <score>}`` computed over per-sentence tag sequences, using
        only the positions whose gold label is not -100.
    """
    logits, labels = eval_pred
    predictions = logits.argmax(-1)
    true_tags, pred_tags = [], []
    for pred_row, gold_row in zip(predictions, labels):
        # Keep (prediction, gold) pairs only where a gold label exists.
        kept = [(int(p), int(g)) for p, g in zip(pred_row, gold_row) if g != -100]
        pred_tags.append([id2label[p] for p, _ in kept])
        true_tags.append([id2label[g] for _, g in kept])
    return {"f1": f1_score(true_tags, pred_tags)}


Some weights of BertForTokenClassification were not initialized from the model checkpoint at dmis-lab/biobert-base-cased-v1.1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


{'O': 0.11428571492433548, 'B-ety': 0.6000000238418579, 'I-ety': 1.0}


In [92]:
# ---- training args ----
from transformers import TrainingArguments

OUT_DIR.mkdir(parents=True, exist_ok=True)

args = TrainingArguments(
    output_dir=str(OUT_DIR),
    do_train=True,
    do_eval=True,                     # keep evaluation enabled
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=1e-5,
    num_train_epochs=50,              # increased number of epochs
    weight_decay=0.01,
    logging_dir=str(OUT_DIR / "logs"),
    logging_steps=10,                 # log the training loss every 10 steps
    save_steps=500,                   # checkpoint interval, in steps
    seed=42,
    # The *string* "none" disables every logging integration (wandb, ...).
    # Passing the Python value None does not: in transformers 4.x it falls
    # back to reporting to all installed integrations.
    report_to="none",
)
# Pads input ids and labels per batch for token classification.
collator = DataCollatorForTokenClassification(tokenizer)

In [93]:
# WeightedTrainer
from transformers import Trainer
import torch.nn as nn

class WeightedTrainer(Trainer):
    """Trainer whose loss is class-weighted, label-smoothed cross-entropy.

    Args:
        *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
        class_weights: Optional 1-D tensor with one weight per class. It is
            kept as passed in and moved to the logits' device at every step.
            ``None`` means unweighted cross-entropy.
        label_smoothing: Label-smoothing factor for the loss (default 0.1).
    """

    def __init__(self, *args, class_weights=None, label_smoothing=0.1, **kwargs):
        # processing_class / tokenizer are deliberately not handled here;
        # they are passed straight through to Trainer to avoid conflicts.
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights
        self.label_smoothing = label_smoothing

    # **kwargs absorbs the extra arguments that newer Trainer versions pass.
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Run the model on ``inputs`` and return the weighted token-level loss.

        Args:
            model: Model to call; its output must expose ``.logits``.
            inputs: Batch dict; its ``"labels"`` entry is removed before the
                forward pass so the model does not compute its own loss.
            return_outputs: If true, return ``(loss, outputs)``.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs``.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits  # [B, T, C]

        # Move the weights to the logits' device on every step so CPU-held
        # weights never clash with CUDA logits. No weights -> plain CE.
        weight = None
        if self.class_weights is not None:
            weight = self.class_weights.to(logits.device)

        loss_fct = nn.CrossEntropyLoss(
            weight=weight,
            ignore_index=-100,  # masked positions contribute nothing
            label_smoothing=self.label_smoothing,
        )

        # Flatten to [B*T, C] vs [B*T]; reshape also handles non-contiguous tensors.
        loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

        return (loss, outputs) if return_outputs else loss

In [94]:
import os
# Keep wandb switched off even if something tries to initialise it.
os.environ["WANDB_MODE"] = "disabled"
trainer = WeightedTrainer(
    model=model,
    args=args,
    train_dataset=tokenized["train"],
    eval_dataset=tokenized["validation"],
    # `processing_class` replaces the deprecated `tokenizer=` keyword of
    # Trainer (which emits a FutureWarning in transformers 4.46+).
    processing_class=tokenizer,
    data_collator=collator,
    compute_metrics=compute_metrics,
    class_weights=class_weights,  # no .to(model.device) needed here
)
trainer.train()

  super().__init__(*args, **kwargs)


Step,Training Loss
10,0.6193
20,0.5139
30,0.491
40,0.4858
50,0.4846


TrainOutput(global_step=50, training_loss=0.5189181995391846, metrics={'train_runtime': 21.7269, 'train_samples_per_second': 4.603, 'train_steps_per_second': 2.301, 'total_flos': 2806924131000.0, 'train_loss': 0.5189181995391846, 'epoch': 50.0})

In [95]:
# Distribution of predicted tags on the dev split (labelled positions only).
out = trainer.predict(tokenized["validation"])
import numpy as np
from collections import Counter

pred_ids = out.predictions.argmax(axis=-1)
true_ids = out.label_ids

pred_tags = [
    id2label[int(pred_id)]
    for pred_row, gold_row in zip(pred_ids, true_ids)
    for pred_id, gold_id in zip(pred_row, gold_row)
    if gold_id != -100
]
print("Pred label dist on DEV:", Counter(pred_tags))

Pred label dist on DEV: Counter({'O': 2085, 'I-ety': 945, 'B-ety': 515})


In [96]:
# Distribution of gold tags on the dev split, for comparison with the
# predicted distribution; positions labelled -100 are left out.
true_tags = [
    id2label[int(gold_id)]
    for gold_row in true_ids
    for gold_id in gold_row
    if gold_id != -100
]
print("Gold label dist on DEV:", Counter(true_tags))

Gold label dist on DEV: Counter({'O': 3285, 'B-ety': 134, 'I-ety': 126})


In [97]:
# Diagnostic cell: show the GPU status via the shell and print whether
# PyTorch reports CUDA as available.
!nvidia-smi
import torch; print("cuda?", torch.cuda.is_available())

Thu Oct  9 01:46:34 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   60C    P0             29W /   70W |    1980MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [98]:
# ---- evaluate (dev + test) ----
def eval_split(name):
    """Evaluate the trainer on one tokenized split and print its F1.

    Args:
        name: Key into ``tokenized`` (e.g. ``"validation"`` or ``"test"``).

    Returns:
        The ``eval_f1`` value reported by ``trainer.evaluate``.
    """
    metrics = trainer.evaluate(tokenized[name])
    score = metrics["eval_f1"]
    print(f"{name.upper()} F1:", round(score, 4))
    return score

f1_dev = eval_split("validation")
f1_test = eval_split("test")

# ---- detailed per-entity report on test ----
pred_logits = trainer.predict(tokenized["test"]).predictions
pred_ids = pred_logits.argmax(-1)
pred_tags, true_tags = [], []
for pred_row, gold_row in zip(pred_ids, tokenized["test"]["labels"]):
    # Keep (prediction, gold) pairs only where the gold label is not -100.
    kept = [(int(p), int(g)) for p, g in zip(pred_row, gold_row) if g != -100]
    pred_tags.append([id2label[p] for p, _ in kept])
    true_tags.append([id2label[g] for _, g in kept])

print("\nClassification report (test):")
print(classification_report(true_tags, pred_tags))

# ---- save minimal metrics (dev/test F1 only) ----
import json
metrics_payload = {"f1_dev": float(f1_dev), "f1_test": float(f1_test)}
(OUT_DIR/"metrics.json").write_text(json.dumps(metrics_payload, indent=2))
print(f"\nSaved metrics to {OUT_DIR}/metrics.json")

VALIDATION F1: 0.1307
TEST F1: 0.1098

Classification report (test):
              precision    recall  f1-score   support

         ety       0.06      0.53      0.11       516

   micro avg       0.06      0.53      0.11       516
   macro avg       0.06      0.53      0.11       516
weighted avg       0.06      0.53      0.11       516


Saved metrics to /content/drive/MyDrive/small_data_NER_project/results/biobert_k5_full/metrics.json
