# W2NER для русского клинического NER

## Извлечение медицинских сущностей из клинических рекомендаций Минздрава РФ

Исходник модели тут: https://github.com/ljynlp/W2NER


---
## Setup & Installation

In [None]:
import os, sys, glob, zipfile, pathlib, shutil, json, re
from types import SimpleNamespace
from collections import Counter, defaultdict

!pip install -q gensim

import numpy as np, pandas as pd
import torch
from transformers import AutoTokenizer
from torch.utils.data import DataLoader
from tqdm import tqdm

# Working-directory root; every other path below is derived from it.
CONTENT = "/content"
# Archive with the W2NER sources and the directory it is expected to unpack into.
REPO_ZIP = f"{CONTENT}/W2NER-main.zip"
REPO = f"{CONTENT}/W2NER-main"
# Dataset folders inside the repository: training splits and prediction inputs.
DATA_GUIDE = f"{REPO}/data/guidelines_ru"
DATA_PRED = f"{REPO}/data/predict"
# Where the trained checkpoint and the prediction/decoding outputs are stored.
MODELS_DIR = f"{REPO}/models"
OUTPUTS_DIR = f"{REPO}/outputs"
# JSON config for this dataset (created or updated in the configuration cell).
CONFIG_PATH = f"{REPO}/config/guidelines_ru.json"

---
## Unpack

In [None]:
# Unpack W2NER-main.zip into the working directory.
with zipfile.ZipFile(REPO_ZIP, mode="r") as repo_archive:
    repo_archive.extractall(path=CONTENT)

In [None]:
# Create every directory the rest of the notebook writes into
# (missing parents are created, existing directories are left alone).
required_dirs = (DATA_GUIDE, DATA_PRED, MODELS_DIR, OUTPUTS_DIR, f"{REPO}/log")
for dir_path in map(pathlib.Path, required_dirs):
    dir_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Patch 1: np.int -> int
# The `np.int` alias no longer exists in recent NumPy releases, so
# `dtype=np.int` is rewritten to the builtin `int`. The negative lookahead
# leaves sized dtypes such as `np.int64` / `np.int_` untouched; a plain
# substring replace would turn `dtype=np.int64` into the undefined `int64`.
p = f"{REPO}/data_loader.py"
with open(p, "r", encoding="utf-8") as f:
    txt = f.read()
patched = re.sub(r"dtype=np\.int(?![A-Za-z0-9_])", "dtype=int", txt)
if patched != txt:
    with open(p, "w", encoding="utf-8") as f:
        f.write(patched)

# Patch 2: transformers.AdamW -> torch.optim.AdamW
p = f"{REPO}/main.py"
with open(p, "r", encoding="utf-8") as f:
    original_main = f.read()

adamw_import = "from torch.optim import AdamW"

# Rewrite a direct import if the file has one.
txt = original_main.replace("from transformers import AdamW", adamw_import)

# Otherwise add the import right after the first bare `import transformers`
# line. Checking for an existing import keeps the patch idempotent, so
# re-running this cell does not stack duplicate import lines.
if adamw_import not in txt:
    txt = re.sub(
        r"^import transformers[ \t]*$",
        "import transformers\n" + adamw_import,
        txt,
        count=1,
        flags=re.M,
    )

# Only redirect the call sites once the replacement name is actually imported;
# otherwise the patched file would reference an undefined `AdamW`.
if adamw_import in txt:
    txt = txt.replace("transformers.AdamW", "AdamW")

if txt != original_main:
    with open(p, "w", encoding="utf-8") as f:
        f.write(txt)


---
## Configuration

In [None]:
# Config update.
# Start from the dataset-specific config when it already exists; otherwise
# derive one from the conll03.json template found in the repository.
if os.path.exists(CONFIG_PATH):
    with open(CONFIG_PATH, "r", encoding="utf-8") as f:
        cfg = json.load(f)
else:
    base_config = f"{REPO}/config/conll03.json"
    with open(base_config, "r", encoding="utf-8") as f:
        cfg = json.load(f)
    cfg["lstm_hid_size"] = 512
    cfg["use_bert_last_4_layers"] = False

cfg["dataset"] = "guidelines_ru"
cfg["save_path"] = f"{MODELS_DIR}/guidelines_ru_model.pt"
cfg["predict_path"] = f"{OUTPUTS_DIR}/guidelines_ru_pred.json"
cfg["epochs"] = 10
cfg["bert_name"] = "bert-base-multilingual-cased"
# The encoder is always forced to a BERT-base checkpoint above, so the hidden
# size is forced too (768 for BERT-base). Previously it was only set when the
# config was derived from the template, which left a stale value in place if
# a pre-existing config had been written for a different encoder.
cfg["bert_hid_size"] = 768

with open(CONFIG_PATH, "w", encoding="utf-8") as f:
    json.dump(cfg, f, ensure_ascii=False, indent=2)


---
## Data Preparation - Training Data

In [None]:
# Stage the prepared train/dev/test splits inside the repository data folder
# under the file names train.json, dev.json and test.json.
src_train = "/content/w2ner_train.json"
src_dev = "/content/w2ner_dev.json"
src_test = "/content/w2ner_test.json"

split_sources = ((src_train, "train"), (src_dev, "dev"), (src_test, "test"))
for split_src, split_name in split_sources:
    shutil.copy(split_src, f"{DATA_GUIDE}/{split_name}.json")


In [None]:
# Launch training from inside the repository so that main.py runs with REPO
# as its working directory.
os.chdir(REPO)
# IPython shell escape; {CONFIG_PATH} is substituted by the notebook before
# the command is executed.
!python main.py --config {CONFIG_PATH}

---
## Model & Data Copy

In [None]:
# Locate the trained checkpoint anywhere under the repository and place a copy
# at the canonical location used by the inference cells.
final_path = f"{MODELS_DIR}/guidelines_ru_model.pt"
all_pt = glob.glob(f"{REPO}/**/*.pt", recursive=True)
if all_pt:
    # Largest file wins, as before; ties (e.g. a fresh checkpoint next to a
    # copy made by an earlier run of this cell) are resolved in favour of the
    # most recently modified file instead of whatever glob lists first.
    trained_model = max(
        all_pt,
        key=lambda path: (os.path.getsize(path), os.path.getmtime(path)),
    )
    # final_path lives under REPO, so the glob can return it; shutil.copy
    # raises SameFileError when source and destination are the same file.
    already_in_place = os.path.exists(final_path) and os.path.samefile(
        trained_model, final_path
    )
    if not already_in_place:
        shutil.copy(trained_model, final_path)
    model_path = final_path
else:
    # No checkpoint found: fall back to the path recorded in the config.
    model_path = cfg["save_path"]


---
## Подготовка к Inference

In [None]:
# Build ID2LABEL, the id -> entity-type dictionary used to name predictions.
#
# The ids must match the ones the model was trained with.
# NOTE(review): the upstream W2NER loader (data_loader.fill_vocab /
# Vocabulary.add_label) assigns ids in first-seen order while scanning train,
# then dev, then test, and lower-cases labels before comparing them. Sorting
# the labels alphabetically, as this cell used to do, matches that only by
# coincidence. Confirm against the unpacked repository.
train_path = f"{DATA_GUIDE}/train.json"
with open(train_path, "r", encoding="utf-8") as f:
    train = json.load(f)

splits = [train]
for split_name in ("dev", "test"):
    split_path = f"{DATA_GUIDE}/{split_name}.json"
    if os.path.exists(split_path):
        with open(split_path, "r", encoding="utf-8") as f:
            splits.append(json.load(f))

# Entity types in id order; the original spelling of the first occurrence is
# kept for display, while uniqueness is decided case-insensitively.
unique_types = []
seen_keys = set()
for records in splits:
    for record in records:
        for entity in record.get('ner', []):
            etype = entity.get('type', '')
            if not etype:
                continue
            key = etype.lower()
            if key not in seen_keys:
                seen_keys.add(key)
                unique_types.append(etype)

# Ids 0 and 1 are reserved: 0 = no relation, 1 = next-neighbouring-word link.
ID2LABEL = {0: "O", 1: "NNW"}
for idx, etype in enumerate(unique_types, start=2):
    ID2LABEL[idx] = etype

---
## Model Loading

In [None]:
# Make the unpacked repository the working directory and put it first on the
# import path: its modules are imported by bare name below.
os.chdir(REPO)
sys.path.insert(0, REPO)

import data_loader
import model as w2model
from config import Config
import utils

# Config is built from an argparse-like namespace.
# NOTE: this rebinds `cfg`, which until now held the plain dict written to
# CONFIG_PATH; cells that index `cfg[...]` will not work after this point.
args = SimpleNamespace(config=CONFIG_PATH, device=0)
cfg = Config(args)
cfg.logger = utils.get_logger(cfg.dataset)
# The return value is discarded; presumably the call is needed for what it
# sets on `cfg` (e.g. label vocabulary / label count used to size the model)
# - TODO confirm against data_loader.load_data_bert.
_ = data_loader.load_data_bert(cfg)

# Use the GPU when one is available, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = w2model.Model(cfg).to(device)
# Weights are deserialized onto CPU first; strict=True makes any mismatch
# between checkpoint keys and the model's parameters an error.
net.load_state_dict(torch.load(model_path, map_location="cpu"), strict=True)
# Switch to evaluation mode for inference.
net.eval()

# Tokenizer for the encoder named in the config.
tokenizer = AutoTokenizer.from_pretrained(cfg.bert_name)

---
## Загрузка тестовых данных (MD файлы)

In [None]:
# Unpack the archive with the test guidelines (when present) and collect
# every Markdown file below the test directory, in sorted order.
TEST_ZIP = f"{CONTENT}/minzdrav_test.zip"
TEST_DIR = f"{CONTENT}/minzdrav_test"

if os.path.exists(TEST_ZIP):
    with zipfile.ZipFile(TEST_ZIP, 'r') as z:
        z.extractall(CONTENT)

# The former "fallback" glob searched exactly the same path as this one
# (TEST_DIR is f"{CONTENT}/minzdrav_test"), so it could never find anything
# new and has been dropped.
md_files = sorted(glob.glob(f"{TEST_DIR}/**/*.md", recursive=True))
print(f"Найдено файлов: {len(md_files)}")

# Make an empty result visible here instead of letting the inference cells
# run silently on zero documents.
if not md_files:
    print(
        f"Внимание: в {TEST_DIR} нет .md файлов — проверьте, что {TEST_ZIP} "
        "существует и распаковывается в эту папку."
    )

---
## Inference (Text processing)

In [None]:
# Maximum number of tokens (as produced by token_pat) in one chunk of a
# sentence; longer sentences are cut into consecutive chunks of this size.
MAX_TOKENS = 220
# Maximum word-piece count of one model input as measured by wp_len(), which
# counts with special tokens included; ensure_wp_limit() splits chunks that
# exceed it.
MAX_WP = 510

def strip_markdown(md: str) -> str:
    """Reduce Markdown source to a single line of plain text.

    The substitutions run in this order:

    1. fenced code blocks (``` ... ```, possibly multi-line) become a space;
    2. any ``|...|`` span within one line becomes a space;
    3. leading ``#`` heading markers (and the whitespace after them) are
       removed;
    4. the characters ``* _ > ` ~`` become spaces;
    5. every whitespace run, line breaks included, collapses to one space.

    The result is stripped of leading and trailing whitespace, so it never
    contains a line break.
    """
    substitutions = (
        (r"```.*?```", " ", re.S),
        (r"\|.*?\|", " ", 0),
        (r"^#+\s*", "", re.M),
        (r"[*_>`~]", " ", 0),
        (r"\s+", " ", 0),
    )
    text = md
    for pattern, replacement, flags in substitutions:
        text = re.sub(pattern, replacement, text, flags=flags)
    return text.strip()

# Sentence boundary: whitespace that follows '.', '!' or '?', or a run of
# line breaks.
# NOTE(review): strip_markdown() collapses every whitespace run (line breaks
# included) into a single space, so on its output the line-break alternative
# never matches.
sent_split = re.compile(r"(?<=[\.\!\?])\s+|[\n\r]+")
# Token: a run of Latin/Cyrillic letters, a number with an optional decimal
# part ('.' or ',' as separator), a run of the symbols % ° × x /, or any other
# single non-whitespace character. Alternatives are tried left to right.
token_pat = re.compile(r"[A-Za-zА-Яа-яЁё]+|\d+(?:[.,]\d+)?|[%°×x/]+|[^\s]")

def wp_len(tokens):
    """Return how many input ids the tokenizer produces for *tokens*.

    *tokens* is a pre-split list of words. Special tokens are included in the
    count and nothing is truncated.
    """
    encoding = tokenizer(
        tokens,
        is_split_into_words=True,
        add_special_tokens=True,
        truncation=False,
    )
    return len(encoding["input_ids"])

def ensure_wp_limit(tokens):
    """Split *tokens* into consecutive pieces that fit the word-piece limit.

    A piece is accepted when wp_len() reports at most MAX_WP ids for it, or
    when it consists of a single token and cannot be split further. Anything
    else is halved and both halves are processed the same way. The pieces are
    returned in their original left-to-right order; an empty input yields an
    empty list.
    """
    pieces = []
    # Explicit stack instead of recursion; the right half is pushed first so
    # that the left half is handled first.
    pending = [tokens]
    while pending:
        current = pending.pop()
        if not current:
            continue
        if wp_len(current) <= MAX_WP or len(current) <= 1:
            pieces.append(current)
            continue
        half = len(current) // 2
        pending.append(current[half:])
        pending.append(current[:half])
    return pieces

# Turn every Markdown file into model inputs.
# predict_data_full holds the samples in the format expected by the data
# loader; meta_full holds, at the same positions, where each sample came from.
predict_data_full = []
meta_full = []

for fp in md_files:
    filename = os.path.basename(fp)
    # Undecodable bytes are dropped rather than failing the whole file.
    with open(fp, "r", encoding="utf-8", errors="ignore") as f:
        raw = f.read()
    txt = strip_markdown(raw)
    sents = [s.strip() for s in sent_split.split(txt) if s.strip()]

    for s in sents:
        toks = token_pat.findall(s)
        # Skip fragments shorter than three tokens.
        if len(toks) < 3:
            continue

        # Cut the sentence into chunks of at most MAX_TOKENS tokens (a short
        # sentence yields exactly one chunk), then split any chunk that is
        # still over the word-piece limit.
        for start in range(0, len(toks), MAX_TOKENS):
            for sub in ensure_wp_limit(toks[start:start + MAX_TOKENS]):
                predict_data_full.append({"sentence": sub, "ner": []})
                meta_full.append({
                    "source_file": filename,
                    "sentence_text": " ".join(sub),
                    "orig_sentence": s,
                })

# Label vocabulary built from the entity types found in train.json.
with open(f"{DATA_GUIDE}/train.json", "r", encoding="utf-8") as f:
    train_data = json.load(f)

vocab = data_loader.Vocabulary()
for record in train_data:
    for entity in record.get('ner', []):
        label = entity.get('type', '')
        if label:
            # Labels are registered in lower case.
            vocab.add_label(label.lower())

# Encode the prediction samples and wrap them in a dataset.
encoded_columns = data_loader.process_bert(predict_data_full, tokenizer, vocab)
pred_dataset_full = data_loader.RelationDataset(*encoded_columns)

# Batches are produced in input order (no shuffling) in the main process, so
# predictions stay aligned with predict_data_full / meta_full.
loader_options = dict(
    batch_size=8,
    collate_fn=data_loader.collate_fn,
    shuffle=False,
    num_workers=0,
)
pred_loader_full = DataLoader(pred_dataset_full, **loader_options)

# Run the model over every batch and keep, per sample, the grid of predicted
# label ids (argmax over the last dimension of the model output).
all_preds_full = []

with torch.no_grad():
    for raw_batch in tqdm(pred_loader_full, desc="Inference"):
        # Move whatever can be moved onto the target device; other items
        # are passed through unchanged.
        moved = [item.to(device) if hasattr(item, "to") else item for item in raw_batch]

        # Batch positions: 0 = bert_inputs, 2 = grid_mask2d, 3 = pieces2word,
        # 4 = dist_inputs, 5 = sent_length. The model takes them in the order
        # bert_inputs, grid_mask2d, dist_inputs, pieces2word, sent_length.
        logits = net(moved[0], moved[2], moved[4], moved[3], moved[5])

        # Some models return a tuple/list; the first element is used.
        if isinstance(logits, (tuple, list)):
            logits = logits[0]

        all_preds_full.extend(logits.argmax(dim=-1).detach().cpu().numpy())

---
## Decoding

In [None]:
def _decode_grid(pred_grid, n_tokens):
    """Decode one predicted word-pair grid into entities.

    Grid convention (NOTE(review): taken from the upstream W2NER
    ``utils.decode`` / ``data_loader.process_bert``; confirm against the
    unpacked repository):

    * value 1 at ``[i, j]`` with ``i < j`` links word ``i`` to a following
      word ``j`` of the same entity;
    * a value >= 2 at ``[tail, head]`` with ``tail >= head`` is the type id
      of an entity that starts at ``head`` and ends at ``tail``.

    The previous implementation looked for type ids at ``[i, j]`` with
    ``j >= i``, i.e. in the link triangle, so only single-token entities on
    the diagonal could ever be found.

    Args:
        pred_grid: square 2-D array of predicted label ids; it may be larger
            than the sentence because of batch padding.
        n_tokens: number of real tokens in the sentence.

    Returns:
        A sorted list of ``(token_indices, type_id)`` pairs, where
        ``token_indices`` is a tuple of word positions.
    """
    # Ignore rows/columns that exist only because of batch padding.
    size = min(int(pred_grid.shape[0]), n_tokens)

    successors = defaultdict(list)    # word -> following words linked to it
    tails_by_head = defaultdict(dict)  # head -> {tail: type_id}
    for i in range(size):
        for j in range(i, size):
            if j > i and int(pred_grid[i, j]) == 1:
                successors[i].append(j)
            type_id = int(pred_grid[j, i])
            if type_id >= 2:
                tails_by_head[i][j] = type_id

    found = {}

    def walk(path, tails, last_tail):
        # Record the path whenever it currently ends on a predicted tail,
        # then keep following links (they only point forward, so anything
        # beyond the last tail cannot complete an entity).
        current = path[-1]
        if current in tails:
            found[tuple(path)] = tails[current]
        for nxt in successors.get(current, ()):
            if nxt <= last_tail:
                walk(path + [nxt], tails, last_tail)

    for head, tails in tails_by_head.items():
        walk([head], tails, max(tails))

    return sorted(found.items())


# Decode every prediction grid and attach the source metadata.
decoded_full = []
total_entities = 0

for idx, pred_grid in enumerate(all_preds_full):
    meta_item = meta_full[idx]
    tokens = predict_data_full[idx]["sentence"]

    entities = [
        {
            "type_id": type_id,
            "type": ID2LABEL.get(type_id, f"UNKNOWN_{type_id}"),
            "index": list(token_indices),
            "text": " ".join(tokens[k] for k in token_indices),
        }
        for token_indices, type_id in _decode_grid(pred_grid, len(tokens))
    ]

    total_entities += len(entities)

    decoded_full.append({
        "source_file": meta_item["source_file"],
        "sentence_text": meta_item["sentence_text"],
        "orig_sentence": meta_item["orig_sentence"],
        "sentence": tokens,
        "entities": entities
    })

# Save the decoded results.
output_path = f"{OUTPUTS_DIR}/validation_results.json"
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(decoded_full, f, ensure_ascii=False, indent=2)

---
## Statistics

In [None]:
# Overall totals.
print(f"Предложений: {len(decoded_full):,}")
print(f"Сущностей: {total_entities:,}")

# Entity counts per type, most frequent first (ties keep first-seen order).
type_counts = Counter(
    ent.get('type', 'UNKNOWN')
    for rec in decoded_full
    for ent in rec.get('entities', [])
)

print("\nПо типам:")
for etype, count in type_counts.most_common():
    pct = 0
    if total_entities > 0:
        pct = count / total_entities * 100
    print(f"  {etype:<35} {count:>5} ({pct:>5.1f}%)")

# Sentence and entity counts per source file.
by_file = defaultdict(lambda: {"sentences": 0, "entities": 0})
for rec in decoded_full:
    filename = rec["source_file"]
    file_stats = by_file[filename]
    file_stats["sentences"] += 1
    file_stats["entities"] += len(rec.get('entities', []))

print("\nПо файлам:")
for filename, stats in sorted(by_file.items(), key=lambda item: item[0]):
    avg = 0
    if stats["sentences"] > 0:
        avg = stats["entities"] / stats["sentences"]
    print(f"  {filename:<50} │ {stats['entities']:>4} сущ. │ {avg:.2f}")
























---
## Conclusion

**ЧТО РАБОТАЕТ:**

- Основные типы (ACTION, ADVERSE_EVENT, DRUG, AGE_GROUP, FOLLOW_UP) — 97 % всех сущностей.
- Модель стабильна на всех 9 файлах.
- F1 ≈ 92–95 % судя по результатам.

**ЧТО МОЖНО УЛУЧШИТЬ:**

- PROCEDURE_TEST - только 28 (в train было 482)
- LAB_VALUE - только 7 (в train было 110)
- SEVERITY/STAGE - только 1 (в train было 66)

**ПРИЧИНА:** валидационные данные (про рак, лейкемию и др.) содержат другие типы сущностей, чем обучающие данные, поэтому необходимо дальнейшее обучение модели.