In [1]:
import os, glob
os.system("git clone https://github.com/VinAIResearch/PhoNER_COVID19.git")
os.chdir("PhoNER_COVID19")

print("cwd:", os.getcwd())
print("word files:", glob.glob("data/word/*"))


cwd: /content/PhoNER_COVID19
word files: ['data/word/dev_word.json', 'data/word/train_word.json', 'data/word/dev_word.conll', 'data/word/test_word.json', 'data/word/test_word.conll', 'data/word/train_word.conll']


In [2]:
def read_conll(path):
    sents_tokens, sents_tags = [], []
    tokens, tags = [], []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                if tokens:
                    sents_tokens.append(tokens)
                    sents_tags.append(tags)
                    tokens, tags = [], []
                continue
            parts = line.split()
            # thường là: token tag (2 cột)
            token = parts[0]
            tag = parts[-1]
            tokens.append(token)
            tags.append(tag)
    if tokens:
        sents_tokens.append(tokens)
        sents_tags.append(tags)
    return sents_tokens, sents_tags

train_tokens, train_tags = read_conll("data/word/train_word.conll")
dev_tokens, dev_tags     = read_conll("data/word/dev_word.conll")
test_tokens, test_tags   = read_conll("data/word/test_word.conll")

len(train_tokens), len(dev_tokens), len(test_tokens)


(5027, 2000, 3000)

In [3]:
i = 0
print(train_tokens[i][:30])
print(train_tags[i][:30])


['Đồng_thời', ',', 'bệnh_viện', 'tiếp_tục', 'thực_hiện', 'các', 'biện_pháp', 'phòng_chống', 'dịch_bệnh', 'COVID', '-', '19', 'theo', 'hướng_dẫn', 'của', 'Bộ', 'Y_tế', '.']
['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'B-ORGANIZATION', 'I-ORGANIZATION', 'O']


In [4]:
from collections import Counter

PAD = "<pad>"
UNK = "<unk>"

counter = Counter()
for sent in train_tokens:
    counter.update(sent)

idx2word = [PAD, UNK] + list(counter.keys())
word2idx = {w:i for i,w in enumerate(idx2word)}
vocab_size = len(idx2word)

all_tags = sorted({t for sent in train_tags for t in sent})
tag2idx = {t:i for i,t in enumerate(all_tags)}
idx2tag = {i:t for t,i in tag2idx.items()}
num_tags = len(all_tags)

vocab_size, num_tags, all_tags[:15]


(5243,
 20,
 ['B-AGE',
  'B-DATE',
  'B-GENDER',
  'B-JOB',
  'B-LOCATION',
  'B-NAME',
  'B-ORGANIZATION',
  'B-PATIENT_ID',
  'B-SYMPTOM_AND_DISEASE',
  'B-TRANSPORTATION',
  'I-AGE',
  'I-DATE',
  'I-JOB',
  'I-LOCATION',
  'I-NAME'])

In [5]:
import torch
from torch.utils.data import Dataset, DataLoader

MAX_LEN = 160

class NERDataset(Dataset):
    def __init__(self, sents_tokens, sents_tags, word2idx, tag2idx, max_len=160):
        self.X = sents_tokens
        self.Y = sents_tags
        self.word2idx = word2idx
        self.tag2idx = tag2idx
        self.max_len = max_len

    def __len__(self):
        return len(self.X)

    def __getitem__(self, i):
        tokens = self.X[i][:self.max_len]
        tags = self.Y[i][:self.max_len]
        ids = [self.word2idx.get(w, self.word2idx[UNK]) for w in tokens]
        y = [self.tag2idx[t] for t in tags]
        return ids, y

def collate_fn(batch):
    pad_id = word2idx[PAD]
    pad_tag = -100  # ignore_index cho CrossEntropyLoss

    maxl = max(len(x[0]) for x in batch)
    input_ids, attn_mask, labels = [], [], []

    for ids, y in batch:
        pad_len = maxl - len(ids)
        input_ids.append(ids + [pad_id]*pad_len)
        attn_mask.append([1]*len(ids) + [0]*pad_len)
        labels.append(y + [pad_tag]*pad_len)

    return (
        torch.tensor(input_ids, dtype=torch.long),
        torch.tensor(attn_mask, dtype=torch.long),
        torch.tensor(labels, dtype=torch.long),
    )


In [6]:
BATCH_SIZE = 32

train_loader = DataLoader(NERDataset(train_tokens, train_tags, word2idx, tag2idx, MAX_LEN),
                          batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
dev_loader   = DataLoader(NERDataset(dev_tokens, dev_tags, word2idx, tag2idx, MAX_LEN),
                          batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(NERDataset(test_tokens, test_tags, word2idx, tag2idx, MAX_LEN),
                          batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

len(train_loader), len(dev_loader), len(test_loader)


(158, 63, 94)

In [7]:
import math
import torch.nn as nn

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=512):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # [1,T,D]

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


In [8]:
import torch.nn.functional as F

class TransformerEncoderTagger(nn.Module):
    def __init__(self, vocab_size, num_tags, pad_id,
                 d_model=256, nhead=8, dim_ff=1024, num_layers=3, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.pad_id = pad_id

        self.emb = nn.Embedding(vocab_size, d_model, padding_idx=pad_id)
        self.pos = PositionalEncoding(d_model, dropout=dropout, max_len=512)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_ff,
            dropout=dropout,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)
        self.classifier = nn.Linear(d_model, num_tags)

    def forward(self, input_ids, attn_mask):
        # input_ids: [B,T], attn_mask: [B,T]
        x = self.emb(input_ids) * math.sqrt(self.d_model)
        x = self.pos(x)

        src_key_padding_mask = (attn_mask == 0)  # True ở PAD
        x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)  # [B,T,D]
        logits = self.classifier(x)  # [B,T,C]
        return logits


In [9]:
import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TransformerEncoderTagger(
    vocab_size=vocab_size,
    num_tags=num_tags,
    pad_id=word2idx[PAD],
    d_model=256,
    nhead=8,
    dim_ff=1024,
    num_layers=3,   # đúng yêu cầu 3 lớp
    dropout=0.1
).to(DEVICE)

sum(p.numel() for p in model.parameters())/1e6


3.716628

In [10]:
LR = 3e-4
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

def bio_to_spans(tags):
    # tags: list[str] BIO (ví dụ B-LOC, I-LOC, O)
    spans = []
    start, ent_type = None, None
    for i, t in enumerate(tags + ["O"]):  # sentinel
        if t == "O" or t == "PAD":
            if start is not None:
                spans.append((start, i-1, ent_type))
                start, ent_type = None, None
            continue
        if t.startswith("B-"):
            if start is not None:
                spans.append((start, i-1, ent_type))
            start = i
            ent_type = t[2:]
        elif t.startswith("I-"):
            typ = t[2:]
            if start is None or typ != ent_type:
                # I- bị lỗi -> coi như B-
                if start is not None:
                    spans.append((start, i-1, ent_type))
                start = i
                ent_type = typ
    return spans

@torch.no_grad()
def eval_entity_f1(model, loader):
    model.eval()
    tp = fp = fn = 0
    token_correct = token_total = 0
    total_loss = 0.0
    n_samples = 0

    for input_ids, attn_mask, labels in loader:
        input_ids = input_ids.to(DEVICE)
        attn_mask = attn_mask.to(DEVICE)
        labels = labels.to(DEVICE)

        logits = model(input_ids, attn_mask)
        loss = criterion(logits.view(-1, num_tags), labels.view(-1))
        total_loss += loss.item() * input_ids.size(0)
        n_samples += input_ids.size(0)

        preds = logits.argmax(-1)  # [B,T]
        for b in range(input_ids.size(0)):
            # lấy phần token thật (attn_mask=1)
            L = int(attn_mask[b].sum().item())
            gold_ids = labels[b][:L].tolist()
            pred_ids = preds[b][:L].tolist()

            gold_tags = [idx2tag[i] for i in gold_ids]
            pred_tags = [idx2tag[i] for i in pred_ids]

            # token-acc
            token_correct += sum(g==p for g,p in zip(gold_ids, pred_ids))
            token_total += L

            gold_spans = set(bio_to_spans(gold_tags))
            pred_spans = set(bio_to_spans(pred_tags))

            tp += len(gold_spans & pred_spans)
            fp += len(pred_spans - gold_spans)
            fn += len(gold_spans - pred_spans)

    precision = tp / (tp + fp + 1e-9)
    recall    = tp / (tp + fn + 1e-9)
    f1        = 2*precision*recall / (precision + recall + 1e-9)
    token_acc = token_correct / max(1, token_total)
    return total_loss/max(1,n_samples), token_acc, precision, recall, f1


In [11]:
def train_one_epoch(model, loader):
    model.train()
    total_loss = 0.0
    n_samples = 0

    for input_ids, attn_mask, labels in loader:
        input_ids = input_ids.to(DEVICE)
        attn_mask = attn_mask.to(DEVICE)
        labels = labels.to(DEVICE)

        optimizer.zero_grad()
        logits = model(input_ids, attn_mask)
        loss = criterion(logits.view(-1, num_tags), labels.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item() * input_ids.size(0)
        n_samples += input_ids.size(0)

    return total_loss / max(1, n_samples)


In [12]:
EPOCHS = 30
PATIENCE = 3
best_path = "best_ner_encoder3.pt"

best_f1 = -1.0
patience_left = PATIENCE

for epoch in range(1, EPOCHS+1):
    tr_loss = train_one_epoch(model, train_loader)
    dev_loss, dev_tok_acc, dev_p, dev_r, dev_f1 = eval_entity_f1(model, dev_loader)

    improved = dev_f1 > best_f1 + 1e-4

    print(
        f"Epoch {epoch:02d} | "
        f"train_loss={tr_loss:.4f} | dev_loss={dev_loss:.4f} | "
        f"dev_tok_acc={dev_tok_acc:.4f} | dev_P={dev_p:.4f} | dev_R={dev_r:.4f} | dev_F1={dev_f1:.4f} | "
        f"{'improved' if improved else 'no-improve'} | patience_left={patience_left}"
    )

    if improved:
        best_f1 = dev_f1
        patience_left = PATIENCE
        torch.save(model.state_dict(), best_path)
    else:
        patience_left -= 1
        if patience_left == 0:
            print("Early stopping triggered!")
            break

print("Best dev F1:", best_f1)
print("Saved:", best_path)


  output = torch._nested_tensor_from_mask(


Epoch 01 | train_loss=0.6185 | dev_loss=0.5051 | dev_tok_acc=0.8538 | dev_P=0.3762 | dev_R=0.4194 | dev_F1=0.3966 | improved | patience_left=3
Epoch 02 | train_loss=0.3329 | dev_loss=0.3945 | dev_tok_acc=0.8827 | dev_P=0.4879 | dev_R=0.5516 | dev_F1=0.5178 | improved | patience_left=3
Epoch 03 | train_loss=0.2562 | dev_loss=0.3594 | dev_tok_acc=0.8942 | dev_P=0.5174 | dev_R=0.6253 | dev_F1=0.5662 | improved | patience_left=3
Epoch 04 | train_loss=0.2132 | dev_loss=0.3384 | dev_tok_acc=0.8990 | dev_P=0.5562 | dev_R=0.6999 | dev_F1=0.6198 | improved | patience_left=3
Epoch 05 | train_loss=0.1818 | dev_loss=0.3519 | dev_tok_acc=0.9055 | dev_P=0.5812 | dev_R=0.6539 | dev_F1=0.6154 | no-improve | patience_left=3
Epoch 06 | train_loss=0.1626 | dev_loss=0.3219 | dev_tok_acc=0.9095 | dev_P=0.6007 | dev_R=0.6906 | dev_F1=0.6425 | improved | patience_left=2
Epoch 07 | train_loss=0.1456 | dev_loss=0.3220 | dev_tok_acc=0.9085 | dev_P=0.5807 | dev_R=0.7133 | dev_F1=0.6402 | no-improve | patience_le

In [13]:
model.load_state_dict(torch.load(best_path, map_location=DEVICE))
test_loss, test_tok_acc, test_p, test_r, test_f1 = eval_entity_f1(model, test_loader)

print("TEST loss:", test_loss)
print("TEST token-acc:", test_tok_acc)
print("TEST P/R/F1:", test_p, test_r, test_f1)


TEST loss: 0.40404242599010465
TEST token-acc: 0.9111419767496148
TEST P/R/F1: 0.6255040706078806 0.70055389859389 0.6609052169627591


In [14]:
@torch.no_grad()
def eval_entity_f1_per_type(model, loader):
    model.eval()
    stats = {}  # type -> tp, fp, fn

    for input_ids, attn_mask, labels in loader:
        input_ids = input_ids.to(DEVICE)
        attn_mask = attn_mask.to(DEVICE)
        labels = labels.to(DEVICE)

        logits = model(input_ids, attn_mask)
        preds = logits.argmax(-1)

        for b in range(input_ids.size(0)):
            L = int(attn_mask[b].sum().item())
            gold_tags = [idx2tag[i] for i in labels[b][:L].tolist()]
            pred_tags = [idx2tag[i] for i in preds[b][:L].tolist()]

            gold_spans = bio_to_spans(gold_tags)
            pred_spans = bio_to_spans(pred_tags)

            gold_by_type = {}
            pred_by_type = {}
            for s,e,t in gold_spans:
                gold_by_type.setdefault(t,set()).add((s,e))
            for s,e,t in pred_spans:
                pred_by_type.setdefault(t,set()).add((s,e))

            all_types = set(gold_by_type) | set(pred_by_type)
            for t in all_types:
                g = gold_by_type.get(t,set())
                p = pred_by_type.get(t,set())
                tp = len(g & p)
                fp = len(p - g)
                fn = len(g - p)
                if t not in stats:
                    stats[t] = [0,0,0]
                stats[t][0] += tp
                stats[t][1] += fp
                stats[t][2] += fn

    # print
    print("="*72)
    print(f"{'Type':<15} {'Precision':>10} {'Recall':>10} {'F1':>10} {'Support':>10}")
    print("-"*72)

    for t in sorted(stats.keys()):
        tp, fp, fn = stats[t]
        p = tp / (tp + fp + 1e-9)
        r = tp / (tp + fn + 1e-9)
        f1 = 2*p*r / (p+r+1e-9)
        support = tp + fn
        print(f"{t:<15} {p:10.4f} {r:10.4f} {f1:10.4f} {support:10d}")

    print("="*72)

eval_entity_f1_per_type(model, test_loader)


Type             Precision     Recall         F1    Support
------------------------------------------------------------------------
AGE                 0.7457     0.8866     0.8100        582
DATE                0.5630     0.7648     0.6486       1654
GENDER              0.8788     0.9416     0.9091        462
JOB                 0.4659     0.4740     0.4699        173
LOCATION            0.5991     0.6609     0.6285       4441
NAME                0.8384     0.5220     0.6434        318
ORGANIZATION        0.3590     0.5979     0.4487        771
PATIENT_ID          0.8526     0.7676     0.8079       2005
SYMPTOM_AND_DISEASE     0.5951     0.6417     0.6175       1136
TRANSPORTATION      0.7623     0.4819     0.5905        193
