# [世界のUniversal Dependenciesと係り受けツール群](http://kanji.zinbun.kyoto-u.ac.jp/~yasuoka/publications/2021-06-22.pdf)
## 日本語UDを用いた係り受け解析器の自作
### Transformersとbert-large-japanese-char-extendedを用いる場合


必要なパッケージと訓練用train.conlluを準備

In [ ]:
# Clone the UD_Japanese-GSD treebank (shallow) unless the directory already exists.
!test -d UD_Japanese-GSD || git clone --depth=1 https://github.com/universaldependencies/UD_Japanese-GSD
# Expose its training split as train.conllu via a symlink, created only if missing.
!test -f train.conllu || ln -s UD_Japanese-GSD/ja_gsd-ud-train.conllu train.conllu
# Install the Python packages imported by the cells below.
!pip install transformers datasets deplacy

train-c.conlluをtrain.conlluから作成

In [ ]:
from transformers import AutoTokenizer as AT
# Rewrite train.conllu as train-c.conllu, where every word is cut into the
# pieces produced by the tokenizer.  The first piece of a word keeps the
# word's annotation; every following piece is tagged X and attached to the
# first piece with the relation "goeswith".
tkz = AT.from_pretrained("KoichiYasuoka/bert-large-japanese-char-extended")
with open("train.conllu", "r", encoding="utf-8") as f:
  r = f.read()
with open("train-c.conllu", "w", encoding="utf-8") as f:
  rows = []     # output rows collected for the current sentence
  first = [0]   # first[k] = new ID of the first piece of original word k (0 = root)
  for line in r.split("\n"):
    if line.startswith("# text = "):
      # Tokenize the raw sentence once; words consume these pieces in order.
      pieces = tkz.tokenize(line[9:])
      header = line
      continue
    if line == "":
      # Blank line ends a sentence: renumber the rows and emit them.
      if len(rows) > 0:
        body = []
        for n, row in enumerate(rows):
          # row[6] still holds an original word ID; map it to the new numbering.
          cols = [str(n+1)] + row[1:6] + [str(first[row[6]])] + row[7:]
          body.append("\t".join(cols))
        print(header, "\n".join(body), "", sep="\n", file=f)
        rows = []
        first = [0]
      continue
    if line.startswith("#"):
      continue
    col = line.split("\t")
    # Remember whether the word itself was followed by a space.
    misc = "SpaceAfter=No" if "SpaceAfter=No" in col[9] else "_"
    rest = col[1]
    col[2] = "_"
    col[6] = int(col[6])
    col[8] = "_"
    col[9] = "SpaceAfter=No"
    first.append(len(rows)+1)
    while rest != "":
      piece = pieces[len(rows)]
      col[1] = piece
      # Drop the characters covered by this piece (ignoring any "##" marker).
      rest = rest[len(piece.replace("##", "")):]
      rows.append(list(col))
      # Any further piece of this word is a continuation of the first one.
      col[3:8] = ["X", "_", "_", int(col[0]), "goeswith"]
    # Only the last piece of the word carries the word's own spacing flag.
    rows[-1][9] = misc

my.transを作成

In [ ]:
from transformers import (AutoTokenizer, AutoConfig,
  AutoModelForTokenClassification, DataCollatorForTokenClassification,
  TrainingArguments, Trainer)
from datasets.arrow_dataset import Dataset
# Fine-tune a token-classification model on train-c.conllu and save it,
# together with its tokenizer, under "my.trans".
brt = "KoichiYasuoka/bert-large-japanese-char-extended"
with open("train-c.conllu", "r", encoding="utf-8") as f:
  tok,tag = [],[]
  # Sentences are separated by blank lines; comment lines are skipped.
  for s in f.read().strip().split("\n\n"):
    v = [t.split("\t") for t in s.split("\n") if not t.startswith("#")]
    tok.append([t[1] for t in v])
    # One label per token: columns 3, 4, 5, the signed offset from the token
    # to its head ("0" when the head column is 0), and column 7, tab-joined.
    tag.append(["\t".join([t[3], t[4], t[5], ("{:+}" if int(t[6]) else "0")
      .format(int(t[6])-int(t[0])), t[7]]) for t in v])
# Collect the distinct labels in one linear pass and sort them, so the
# label -> id mapping is identical on every run (iteration order of a set of
# strings varies between interpreter runs) and no quadratic list
# concatenation is needed.
lid = {l:i for i,l in enumerate(sorted({l for s in tag for l in s}))}
tkz = AutoTokenizer.from_pretrained(brt)
# Tokens are already tokenizer pieces, so they are mapped to ids directly.
dts = Dataset.from_dict({"tokens": tok, "tags": tag,
  "input_ids": [tkz.convert_tokens_to_ids(s) for s in tok],
  "labels": [[lid[t] for t in s] for s in tag]})
cfg = AutoConfig.from_pretrained(brt, num_labels=len(lid), label2id=lid,
  id2label={i:l for l,i in lid.items()})
mdl = AutoModelForTokenClassification.from_pretrained(brt, config=cfg)
dcl = DataCollatorForTokenClassification(tokenizer=tkz)
arg = TrainingArguments(output_dir="/tmp", overwrite_output_dir=True,
  per_device_train_batch_size=4)
trn = Trainer(model=mdl, args=arg, data_collator=dcl, train_dataset=dts)
trn.train()
trn.save_model("my.trans")
tkz.save_pretrained("my.trans")

my.transで係り受け解析

In [ ]:
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Load the tokenizer and the token-classification model saved under "my.trans".
tkz = AutoTokenizer.from_pretrained("my.trans")
mdl = AutoModelForTokenClassification.from_pretrained("my.trans")
def nlp(sentence):
  """Parse `sentence` with the loaded model and return it as tab-separated rows.

  The sentence is cut into tokenizer pieces and the model predicts one label
  per piece.  Each label is a tab-joined string whose fourth field is a
  signed offset to the head piece ("0" meaning root); the offset is turned
  into an absolute 1-based ID here.

  Returns one row per piece (ID, piece, "_", the label fields with the head
  made absolute, "_", "SpaceAfter=No"), terminated by a blank line.  When the
  tokenizer yields no pieces only the blank-line terminator is returned.
  """
  s = tkz.tokenize(sentence)
  if not s:
    # Nothing to classify; same text the code below yields for zero rows.
    return "\n\n"
  e = tkz.encode(s, return_tensors="pt", add_special_tokens=False)
  # Inference only: no gradients are needed, so skip autograd bookkeeping.
  with torch.no_grad():
    p = torch.argmax(mdl(e)[0], dim=2)[0].tolist()
  for i,q in enumerate(p):
    t = [s[i],"_"]+mdl.config.id2label[q].split("\t")+["_","SpaceAfter=No"]
    # t[5] is the relative head offset; i+1 is this piece's own ID.
    s[i] = t[0:5]+[str(int(t[5])+i+1) if int(t[5]) else "0"]+t[6:]
  return "\n".join("\t".join([str(i+1)]+t) for i,t in enumerate(s))+"\n\n"
# Parse a sample sentence and print the resulting rows.
doc=nlp("虎穴に入らざれば虎子を得ず。")
print(doc)
import deplacy
# Visualize the parse with deplacy.
# NOTE(review): port=None presumably renders inline instead of starting a
# local server; confirm against the deplacy documentation.
deplacy.serve(doc,port=None)

goeswithを削り取る

In [ ]:
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
# Load the tokenizer and the token-classification model saved under "my.trans".
tkz = AutoTokenizer.from_pretrained("my.trans")
mdl = AutoModelForTokenClassification.from_pretrained("my.trans")
def nlp(sentence):
  """Parse `sentence` and return tab-separated rows with "goeswith" pieces merged.

  The sentence is cut into tokenizer pieces and the model predicts one label
  per piece.  Each label is a tab-joined string whose fourth field is a
  signed offset to the head piece ("0" meaning root); the offset is turned
  into an absolute 1-based ID.  Every piece (except the first row) whose
  relation is "goeswith" is then glued onto the row before it and removed,
  and the head IDs of the remaining rows are renumbered accordingly.

  Returns one row per remaining unit (ID, form, "_", the label fields with
  the head made absolute, "_", "SpaceAfter=No"), terminated by a blank line.
  When the tokenizer yields no pieces only the blank-line terminator is
  returned.
  """
  s = tkz.tokenize(sentence)
  if not s:
    # Nothing to classify; same text the code below yields for zero rows.
    return "\n\n"
  e = tkz.encode(s, return_tensors="pt", add_special_tokens=False)
  # Inference only: no gradients are needed, so skip autograd bookkeeping.
  with torch.no_grad():
    p = torch.argmax(mdl(e)[0], dim=2)[0].tolist()
  for i,q in enumerate(p):
    t = [s[i],"_"]+mdl.config.id2label[q].split("\t")+["_","SpaceAfter=No"]
    # t[5] is the relative head offset; i+1 is this piece's own ID.
    s[i] = t[0:5]+[str(int(t[5])+i+1) if int(t[5]) else "0"]+t[6:]
  # Walk backwards so that removing a row never shifts rows still to be visited.
  for i in range(len(s)-1, 0, -1):
    if s[i][6] != "goeswith":
      continue
    piece = s.pop(i)
    # Append the piece's text to the previous row, dropping a "##" marker.
    s[i-1][0] += piece[0][2:] if piece[0].startswith("##") else piece[0]
    # The removed row had ID i+1: heads at or beyond it move down by one.
    for row in s:
      if int(row[5]) > i:
        row[5] = str(int(row[5])-1)
  return "\n".join("\t".join([str(i+1)]+t) for i,t in enumerate(s))+"\n\n"
# Parse a sample sentence and print the resulting rows.
doc=nlp("虎穴に入らざれば虎子を得ず。")
print(doc)
import deplacy
# Visualize the parse with deplacy.
# NOTE(review): port=None presumably renders inline instead of starting a
# local server; confirm against the deplacy documentation.
deplacy.serve(doc,port=None)