In [4]:
# ШАГ 0. Установка зависимостей (запустить ОДИН раз)
%pip install -U torch torchvision torchaudio \
             transformers==4.41.0 datasets seqeval accelerate \
             snorkel networkx pymupdf tqdm rapidfuzz


Note: you may need to restart the kernel to use updated packages.


In [5]:
# ШАГ 1. Системные импорты и базовые настройки
import os, re, json, glob, logging, random, itertools
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import pandas as pd
import multiprocessing as mp

import fitz                           # PyMuPDF
from rapidfuzz import fuzz, process   # быстрые строчные сопоставления

from transformers import (AutoTokenizer, AutoModelForTokenClassification,
                          DataCollatorForTokenClassification, TrainingArguments,
                          Trainer, pipeline, AutoConfig)
from datasets import Dataset, DatasetDict, ClassLabel
from seqeval.metrics import classification_report
import torch, nltk

nltk.download('punkt')
nltk.download('wordnet')

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Общая конфигурация
class CFG:
    pdf_dir        = "../data"               # папка с входными PDF
    weak_label_dir = "weak_labels"           # куда сохранить промежуточную разметку
    model_ckpt     = "m3rg-iitd/matscibert"  # encoder-only доменная модель
    max_len        = 192
    num_epochs     = 4
    lr             = 3e-5
    batch_size     = 8
cfg = CFG()


[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\knyaz_ayotgwn\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\knyaz_ayotgwn\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


# Тест модели


In [6]:
# ШАГ 2. Извлечение “чистого” текста (ENG only, без формул/рисунков)
from nltk.tokenize import sent_tokenize, word_tokenize

non_eng_pattern = re.compile(r'[А-Яа-яЁё]+')
latex_pattern   = re.compile(r'\$[^$]*\$|\\\[.*?\\\]|\\begin\{.*?}', re.S)

def clean_text(page_txt:str)->str:
    # убираем формулы, пустые строки, кириллицу
    t = latex_pattern.sub(' ', page_txt)
    t = re.sub(r'\s+', ' ', t)
    if non_eng_pattern.search(t):           # если страница русская — пропускаем
        return ""
    return t

def pdf_to_sentences(pdf_path:Path):
    doc = fitz.open(pdf_path)
    raw = " ".join(clean_text(p.get_text("text")) for p in doc)
    doc.close()
    # сегментация предложений
    sents = [s.strip() for s in sent_tokenize(raw) if len(s.split())>3]
    return sents

pdf_files = list(Path(cfg.pdf_dir).rglob("*.pdf"))
logger.info(f"Found {len(pdf_files)} pdfs")

sentences = []
for f in tqdm(pdf_files):
    sentences += pdf_to_sentences(f)
logger.info(f"Total sentences: {len(sentences):,}")


INFO:__main__:Found 0 pdfs
0it [00:00, ?it/s]
INFO:__main__:Total sentences: 0


In [7]:
# ШАГ 3. Словарь металлургических сущностей
gazetteer = {
    "MATERIAL":  ["steel","stainless steel","carbon steel","alloy","copper","aluminium",
                  "nickel","titanium","bronze","cast iron","iron","slag","billet","slab"],
    "EQUIPMENT": ["furnace","converter","ladle","rolling mill","caster","annealing line",
                  "blast furnace","basic oxygen furnace","electric arc furnace"],
    "PROCESS":   ["smelting","rolling","casting","annealing","forging","quenching",
                  "tempering","pickling","hot rolling","cold rolling","heat treatment"],
    "CHEMICAL":  ["carbon","manganese","chromium","silicon","phosphorus","sulfur","vanadium"],
    "STANDARD":  ["ASTM","EN","ISO","DIN","JIS","GOST"],
}

# быстрое обратное индексирование
flat2type = {v.lower():k for k,vs in gazetteer.items() for v in vs}
lex_sorted = sorted(flat2type, key=len, reverse=True)   # longest → shortest

tokenizer = AutoTokenizer.from_pretrained(cfg.model_ckpt)

def weak_label_sentence(sent):
    tokens = tokenizer.tokenize(sent)
    labels = ["O"]*len(tokens)

    sent_low = sent.lower()
    # находим все появления любого ключевого слова
    for kw in lex_sorted:
        for m in re.finditer(r'\b'+re.escape(kw)+r'\b', sent_low):
            # перевести char span → token span
            char_start, char_end = m.span()
            token_spans = tokenizer(sent, return_offsets_mapping=True,
                                    truncation=True, max_length=cfg.max_len).offset_mapping
            toks_in_span = [i for i,(s,e) in enumerate(token_spans)
                            if s>=char_start and e<=char_end]
            if not toks_in_span: continue
            ent_type = flat2type[kw]
            labels[toks_in_span[0]] = f"B-{ent_type}"
            for tid in toks_in_span[1:]:
                labels[tid] = f"I-{ent_type}"

    return tokens, labels

# генерируем разметку sentence→BIO
weak_data = []
for s in tqdm(sentences, desc="weak-label"):
    toks, labs = weak_label_sentence(s)
    if "B-" in " ".join(labs):              # отбрасываем пустые
        weak_data.append({"tokens":toks, "ner_tags":labs, "text":s})

logger.info(f"Weak-labelled sentences: {len(weak_data):,}")


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
weak-label: 0it [00:00, ?it/s]
INFO:__main__:Weak-labelled sentences: 0


In [8]:
# ШАГ 4. В datasets + разделение train/valid
unique_tags = sorted({t for ex in weak_data for t in ex["ner_tags"]})
tag2id = {t:i for i,t in enumerate(unique_tags)}

def encode_example(ex):
    ex["labels"] = [tag2id[t] for t in ex["ner_tags"]]
    return ex

ds = Dataset.from_list(weak_data).map(encode_example, remove_columns=[])
ds = ds.train_test_split(test_size=0.1, seed=42)
ds_dict = DatasetDict({"train":ds["train"], "validation":ds["test"]})
ds_dict = ds_dict.remove_columns(["ner_tags","text"])   # оставим tokens+labels
ds_dict = ds_dict.cast_column("labels", ClassLabel(names=unique_tags))


ValueError: Column name ['ner_tags', 'text'] not in the dataset. Current columns in the dataset: []

In [None]:
# ШАГ 5. Tokenize с выравниванием меток
def align_labels(batch):
    tokenized = tokenizer(batch["tokens"],
                          is_split_into_words=True,
                          truncation=True,
                          padding="max_length",
                          max_length=cfg.max_len)
    new_labels = []
    for i, word_ids in enumerate(tokenized.word_ids(batch_index=None)):
        word_ids = tokenized.word_ids(batch_index=i)
        sent_labels = batch["labels"][i]
        label_ids = []
        prev = None
        for wid in word_ids:
            if wid is None:
                label_ids.append(-100)
            elif wid!=prev:
                label_ids.append(sent_labels[wid])
            else:
                label_ids.append(sent_labels[wid] if sent_labels[wid]!=0 else -100)
            prev = wid
        new_labels.append(label_ids)
    tokenized["labels"] = new_labels
    return tokenized

ds_tok = ds_dict.map(align_labels, batched=True, remove_columns=["tokens"])


In [None]:
# ШАГ 6. Обучение
data_collator = DataCollatorForTokenClassification(tokenizer)
metric = classification_report

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    true, pred = [], []
    for p,l in zip(preds, labels):
        true_seq, pred_seq = [], []
        for pi, li in zip(p,l):
            if li!=-100:
                true_seq.append(unique_tags[li])
                pred_seq.append(unique_tags[pi])
        true.append(true_seq); pred.append(pred_seq)
    report = classification_report(true,pred,output_dict=True,zero_division=0)
    return {"f1": report["weighted avg"]["f1-score"]}

model_config = AutoConfig.from_pretrained(cfg.model_ckpt,
                                          num_labels=len(unique_tags),
                                          id2label={i:t for i,t in enumerate(unique_tags)},
                                          label2id=tag2id)
model = AutoModelForTokenClassification.from_pretrained(cfg.model_ckpt,
                                                        config=model_config)

args = TrainingArguments(
    output_dir="ner_matsci_checkpoint",
    learning_rate=cfg.lr,
    per_device_train_batch_size=cfg.batch_size,
    per_device_eval_batch_size=cfg.batch_size,
    num_train_epochs=cfg.num_epochs,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=50,
    fp16=torch.cuda.is_available(),
)

trainer = Trainer(model=model,
                  args=args,
                  train_dataset=ds_tok["train"],
                  eval_dataset=ds_tok["validation"],
                  data_collator=data_collator,
                  tokenizer=tokenizer,
                  compute_metrics=compute_metrics)

trainer.train()
trainer.save_model("ner_matsci_final")
tokenizer.save_pretrained("ner_matsci_final")


In [None]:
# ШАГ 7. Пайплайн NER с нашей дообученной моделью
ner_pipe = pipeline("ner",
                    model="ner_matsci_final",
                    tokenizer="ner_matsci_final",
                    aggregation_strategy="simple",
                    device=0 if torch.cuda.is_available() else -1)

def extract_entities_matsci(sentences):
    entities = []
    for sent in sentences:
        res = ner_pipe(sent["original"])
        for ent in res:
            ent["sentence_id"] = hash(sent["original"])
            ent["sentence"]    = sent["original"]
            entities.append(ent)
    return entities


In [None]:
class BERTEntityExtractor:
    def __init__(self, model_dir="ner_matsci_final"):
        self.pipe = pipeline("ner",
                             model=model_dir,
                             tokenizer=model_dir,
                             aggregation_strategy="simple",
                             device=0 if torch.cuda.is_available() else -1)

    def extract_entities(self, sentences):
        entities=[]
        for sent in sentences:
            try:
                res=self.pipe(sent["original"])
                for e in res:
                    e["sentence"]=sent["original"]
                    e["sentence_id"]=hash(sent["original"])
                entities.extend(res)
            except Exception as ex:
                logger.warning(f"NER failed: {ex}")
        return entities


In [None]:
pipeline = BaselineKnowledgeGraphPipeline()
result   = pipeline.process_pdf_corpus(cfg.pdf_dir)
json.dump(result, open("kg_results.json","w"), indent=2)
print("Done! KG JSON saved.")
