# Yomikata: Disambiguating Japanese Heteronyms

A step-by-step guide to training Yomikata's word disambiguation model.

# Word pronunciation lists

Cleaning the datasets requires a list of Japanese words and their pronunciations.

We build that list by parsing the UniDic, Sudachi, and KANJIDIC dictionaries. These scripts are slow, but they only need to run once.

In [None]:
from yomikata.dataset.unidic import unidic_data

unidic_data()

In [None]:
from yomikata.dataset.sudachi import sudachi_data

sudachi_data()

In [None]:
from yomikata.dataset.kanjidic import kanjidic_data

kanjidic_data()

In [None]:
from yomikata.dataset.pronunciations import pronunciation_data

pronunciation_data()

In [None]:
from pathlib import Path

import pandas as pd
from yomikata.config import config

df = pd.read_csv(Path(config.READING_DATA_DIR, "all.csv"))
df.sample(10)

# Corpora of annotated sentences

The model is trained on sentences that already have furigana. We process four data sources here. These scripts are slow, but they only need to run once.

[Corpus of titles of works in the National Diet Library](https://github.com/ndl-lab/huriganacorpus-ndlbib)

In [None]:
# from yomikata.dataset.ndlbib import ndlbib_data

# ndlbib_data()

[Aozora Bunko book corpus](https://github.com/ndl-lab/huriganacorpus-aozora)

In [None]:
from yomikata.dataset.aozora import aozora_data

aozora_data()

[Kyoto University document leads corpus](https://github.com/ku-nlp/KWDLC)

In [None]:
from yomikata.dataset.kwdlc import kwdlc_data

kwdlc_data()

[Search results for our heteronyms in the BCCWJ corpus](https://chunagon.ninjal.ac.jp/bccwj-nt/search)

In [None]:
from yomikata.dataset.bccwj import bccwj_data

bccwj_data()

In [None]:
from pathlib import Path
from yomikata.config import config, logger
from yomikata import utils

input_files = [
    Path(config.SENTENCE_DATA_DIR, "aozora.csv"),
    Path(config.SENTENCE_DATA_DIR, "kwdlc.csv"),
    Path(config.SENTENCE_DATA_DIR, "bccwj.csv"),
    # Path(config.SENTENCE_DATA_DIR, "ndlbib.csv"),
]

utils.merge_csvs(input_files, Path(config.SENTENCE_DATA_DIR, "all.csv"), n_header=1)
logger.info("✅ Merged sentence data!")

Filter out duplicate sentences.

In [None]:
from pathlib import Path
import pandas as pd

df = pd.read_csv(Path(config.SENTENCE_DATA_DIR, "all.csv"))
df_no_duplicates = df.drop_duplicates(subset=['sentence'], keep='first')
df_no_duplicates.to_csv(Path(config.SENTENCE_DATA_DIR, "all_filtered.csv"), index=False)
logger.info("✅ Filtered out duplicate sentences!")

# Splitting furigana in the corpus

First we generate a dictionary that represents longer furigana in terms of shorter furigana that appear in the corpus. For example, `{引出/ひきだ}` will be broken down into `{引/ひ}` and `{出/きだ}`. The algorithm looks for a set of shorter furigana whose concatenated surfaces and readings exactly match the whole of the longer furigana, or at least its beginning. It prefers more granular representations (a larger number of shorter furigana), and if two equally granular representations are possible, it picks the one whose furigana have the largest combined frequency in the corpus. It also translates the "ー" character into a specific hiragana representation.
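
The selection rule described above can be sketched on toy data. This is a minimal illustration, not the code inside `generate_breakdown_dictionary`: the frequency table `SHORT_FURIGANA` and the function `best_breakdown` are hypothetical, the counts are invented, and only exact matches of the whole furigana are handled (the "at least the beginning" case and the "ー" translation are left out).

```python
from functools import lru_cache

# Hypothetical frequency table of short furigana seen in a corpus:
# (surface, reading) -> number of occurrences. All counts are invented.
SHORT_FURIGANA = {
    ("引", "ひ"): 120,
    ("出", "きだ"): 90,
    ("引", "ひき"): 40,
    ("出", "だ"): 60,
    ("東京", "とうきょう"): 500,
    ("東", "とう"): 50,
    ("京", "きょう"): 30,
    ("都", "と"): 70,
}


def best_breakdown(surface, reading):
    """Return the preferred exact split of (surface, reading), or None."""

    @lru_cache(maxsize=None)
    def solve(s, r):
        # Returns (piece_count, combined_frequency, pieces), or None if no
        # sequence of short furigana reproduces exactly (s, r).
        if not s and not r:
            return (0, 0, ())
        if not s or not r:
            return None
        best = None
        for (surf, read), freq in SHORT_FURIGANA.items():
            if s.startswith(surf) and r.startswith(read):
                rest = solve(s[len(surf):], r[len(read):])
                if rest is None:
                    continue
                candidate = (rest[0] + 1, rest[1] + freq, ((surf, read),) + rest[2])
                # More pieces wins; ties are broken by combined frequency.
                if best is None or candidate[:2] > best[:2]:
                    best = candidate
        return best

    result = solve(surface, reading)
    return list(result[2]) if result else None


print(best_breakdown("引出", "ひきだ"))
print(best_breakdown("東京都", "とうきょうと"))
```

With these invented counts, both two-piece splits of `引出/ひきだ` are equally granular, so the combined frequency decides; for `東京都/とうきょうと` the three-piece split wins over the two-piece one regardless of frequency.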

In [None]:
from yomikata.dataset.breakdown import generate_breakdown_dictionary

generate_breakdown_dictionary()

Next we use this dictionary to replace all long furigana in the corpus with shorter furigana. By decomposing furigana we get a corpus that allows us to determine readings of kanji surfaces in a more granular way.
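
As a rough picture of what this replacement step does, the sketch below substitutes long furigana with their shorter breakdowns using a regular expression. The mapping format shown here is an assumption made for illustration (the actual layout of `translations.json` is not shown in this notebook), and `decompose` is a hypothetical helper, not `split.decompose_furigana`.

```python
import re

# Hypothetical breakdown dictionary: long furigana -> string of shorter furigana.
toy_split_dict = {"{引出/ひきだ}": "{引/ひ}{出/きだ}"}

# Matches one furigana annotation of the form {surface/reading}.
FURIGANA = re.compile(r"\{[^{}/]+/[^{}/]+\}")


def decompose(sentence, mapping):
    """Replace every furigana that has a known breakdown; leave the rest as-is."""
    return FURIGANA.sub(lambda m: mapping.get(m.group(0), m.group(0)), sentence)


print(decompose("{引出/ひきだ}しを{開/あ}ける", toy_split_dict))
```

Furigana without an entry in the mapping, such as `{開/あ}` here, pass through unchanged.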

In [None]:
from pathlib import Path
from yomikata.dataset import split
from yomikata import utils
from yomikata.config import config, logger

split_dict = utils.load_dict(Path(config.BREAKDOWN_DATA_DIR, "translations.json"))
logger.info("Starting decomposition process, this may take a while...")
split.decompose_furigana(
    Path(config.SENTENCE_DATA_DIR, "all_filtered.csv"),
    Path(config.SENTENCE_DATA_DIR, "all_broken_down.csv"),
    split_dict,
)
logger.info("✅ Decomposed furigana!")

# Making a list of heteronyms

We start from the list in [Sato et al. 2022](https://aclanthology.org/2022.lrec-1.770.pdf). To it we add further heteronyms picked from the corpus according to how often the MeCab tokenizer mispredicts their readings, and arrive at the following list:

In [None]:
# heteronyms = "国立|仮名|遺言|口腔|一途|最中|一行|一夜|下野|花弁|山陰|上下|世論|牧場|一味|施行|施工|転生|清浄|追従|墓石|漢書|作法|黒子|競売|開眼|求道|施業|借家|風車|背筋|逆手|生花|一寸|一分|一文|気骨|細目|船底|相乗|梅雨|風穴|夜話|野兎|冷水|翡翠|十八番|石綿|公文|読本|古本"
heteronyms = "年中|気味|束|影響|夏|父|理|宝塚|我|私|浩|置|彼|手|右|易|柱|次|是|白髪|文字|博士|造作|相|以|割|引|康|博|告|三重|割当|太鼓|期|近平|眼|許|弥|素|坂|安|綴|血|顎|献|手作|江|畑|七兵衛|野|壁|豊|弘法|生花|顎骨|織|梅|北|弘|対|宴|沢|法衣|突|粗|奇怪|心血|野兎|飛騨|塵|報|身体|島津|船|思|幸|先刻|俺|根|牛|新|力|稲|漢書|染|緒|目|春|与|粒|草子|礼拝|黄|舞鶴|電灯|娘|所|道標|掌|白|都|貝塚|岩|博文|缶詰|造|交換|七十|聞|蜂|枕|友|子|鋼|花|続|化|補綴|深|衣|手術|忍|胎|歯|高山|神|吉|馬鈴薯|土器|荷役|昨夜|止|延|細|受|鼻腔|妾|性骨|如何|兵法|脱水|山|箱|帖|雨|共|不足|風呂|陸奥|興|親|上野|廻|果|学|上手|鼻|針|両眼|藍|法師|府|出|共存|島|川|変換|河|鳥|尼|一夜|版|清浄|位|日暮|九十|風土|訴|今|雨水|白血|疱瘡|地|速|黒|文|群|竹|彩|直|台|橋|右衛門|蛙|処|例|宿|跡|涙|法華|毅|先|奥|星|追従|城|卵|鉄|華|太夫|掛|柳|両側|初|暦|不|躯|心|御|尾|格子|泊|今昔|潮|柳田|腔|市場|夜中|下|裕|塩|何分|合戦|建|河岸|身|浅|家|湖|夫|天|玉手|向|薄|郡|環|話|生命|板|笑|懐|藻|湯|車|瓦|考|道|函|老子|細々|一月|一日|郷|道程|細工|文科|此|脆|勢|園|含|眼鏡|守|妖|輪|婦|影|乳房|洋|艶|終|作|泰|犬|光|東|芽腫|見物|端|脂|鉤|三|刻|仁|色|船底|外|咽喉|下顎|特集|田|構|嫌気|顔|定|長|学校|並|肝|良|取|住居|舟|林|相撲|進|六|飛沫|見|同人|落|綴方|夜|負|法|繰|刃|片|数|居|己|一杯|明|伴|富|気質|最中|信夫|満|於|気|市|戸|旬|刑事|能登|一声|殺|折|五十|敏|一寸|名|公文|隆|判例|焼|故郷|青|立|祭|綱|小屋|河口|南|美|工学|陽子|宮|千里|別|年上|体|医学|存|明日|葉|粉|柄|頸|宗|桂|灰|使|日中|歩|科学|聖人|大手|枝|判|吹|施工|強|宿主|法学|好|膝|介|歩兵|一昨年|強力|早|三国|八幡|尺|定家|前|鶏|一味|酔|分別|難治|氷|小学|前駆|裂|風|半月|分間|麗|膠|竜馬|細目|入|栄|児|粋|兵衛|節|管|風車|頭数|雲|露|見透|月|逆手|香|振|山城|雪|花弁|中|油|兄|雑|経緯|古本|合法|大蔵|秋|氏|日向|下手|討|中間|七|哲|魂|表|事|誠|日本|米|問屋|上|清水|高野|牧|物|口|土|至|正|芽|寛|孫|石|助|恋|等|墓石|流|越|桐|重|敬|何|起|時計|命|際|海|化学|太|酒|床|琴|甘藷|枯|声|炎|山河|器|連中|皮|一|様|伝|紀|開眼|宝|銀杏|勝|境|砂|大山|性|虫|側|有|骨|歌|室|時|耳|駿河|間|北方|玩具|元|二十|丈|万|乳|送|衛門|穂|昭|零|哲郎|調|面|底力|通|善|開発|会津|水面|硝子|昨日|緑色|婆|盛土|言|合|編|墨|漁場|陰|源|咳|縁|一行|英文|明清|二重|着|来|筆|借家|信|張|一時|誰|異|静|依存|血症|末|法典|岳|当|電場|梅雨|探|打|墳|相乗|翡翠|望|上顎|魚|荷|語|抱|馳|極|清|巌|聖|技|森|侍|球|女|羽|坊|教|菖蒲|徹|上方|往|彦|緑|候|三角|固|幼|仏|及|下野|宅|武|遺言|九|大勢|福島|翼|黒子|復|緑化|手続|孝|民|輝|赤血|病|係|分|一昨日|母|内|報告|暮|人|世|鬼|決|大和|真|久|勇|兵馬|他|小判|度|堤|厚|嫂|今日|登|小六|古|種|明代|巣|其|火|一言|宏|年|皆|君|剛|雅|花崗|変化|吾妻|赤|袋|里|余|港|淳|獅子|呉|冷水|所謂|鑑|金|鋼板|発足|常|転生|草|疾風|辺|池|墓|巻|綿|小形|角|格付|十八番|治|糸|布|街|観|紙|水|恵|愛|傍|朝|貫|無|部屋|村|日|国立|古今|桜|黄色|修|小|後|額|酒類|指|空|泉|狼|要|貝|四十|仔|薬|広|隠岐|背|四|研究|一途|玉|童|武蔵|石巻|刀|頭蓋|音|噺|本|拍子|公|寺|動力|類聚|殿|館|足跡|鍼|腹|画|達|匹|書|毛|駕|出展|偽|上下|為|実|男|燃|場|教化|姉|歪|鏡|胸|印|附|働|眸|寒気|西郷|司|菓子|程|気骨|世論|末期|人妻|谷間|草紙|寿|歳|基|大社|横|組|山村|灯|本書|志|悪|伸子|求道|底|心肺|高|蔵|戦|大人|会|馬|落葉|雄|頃|訳|競売|人気|茂|二人|町|悲|原|之|平|修業|大分|秘|史学|木|杯|佐|城跡|仮名|夫婦|抜|問題|二|峰|主|子規|
紅葉|彼方|空力|行|白石|熱力|貧|付|動学|一文|明後日|手指|因|手塚|者|風穴|平野|浮|孔|譜|大事|乾|楽|奴|留|創|陽|山陰|生|胎仔|国|三千|紅|独|路|足|倉|品|読|吾|包|秦|沼津|動向|徒然草|方|栄三郎|動静|経|聖徳|日間|施業|保|発|筋|房|裏|頭|沢庵|増|銭|芳|夜話|如|根本|口腔|利益|店|網|嚥下|妻|百|活|権|札|何時|現世|読本|型|大家|十|代|谷|文書|麻|業|形|作法|得|町家|貴女|陰陽|木質|茶道|豚|蚕|帯|千|一方|冬|浪漫|邦|波|心中|味|便|高村|牧場|詩|切|洲|石綿|夢|俊|燕|幻|棟|敷|梁|生物|根治|金色|背筋|大|塚|雷|関|残存|竜|熱|樹|翁|冠|施行|防錆|一目|捧|左|八|問|西|丁|大谷|小倉|草地|笠|答|文学|一分|播"

We search our sentence data for these known heteronyms.

In [None]:
import pandas as pd
from pathlib import Path
from yomikata.config import config, logger

full_df = pd.read_csv(Path(config.SENTENCE_DATA_DIR, "all_broken_down.csv"))
len(full_df)

In [None]:
%%time
df = full_df[
    full_df["sentence"].str.contains(heteronyms)
]
len(df)

In [None]:
from yomikata import utils
from collections import Counter
import pandas as pd
from pathlib import Path
from yomikata.config import config, logger
import random

heteronym_dict = {}
dictionary_df = pd.read_csv(Path(config.READING_DATA_DIR, "all.csv"))
dictionary_set = set(dictionary_df.itertuples(index=False, name=None))

# for heteronym in ["有"]:
for heteronym in heteronyms.split("|"):
    furis = df.loc[df["sentence"].str.contains(heteronym), "furigana"].values
    readings = []
    for furi in furis:
        reading_list = utils.get_all_surface_readings(heteronym, furi)
        readings += reading_list
        # readings += [string for string in reading_list if "ー" not in string and (heteronym, string) in dictionary_set]
    ms = Counter(readings)
    ms = {k: v for k, v in sorted(ms.items(), key=lambda item: item[1], reverse=True)}
    print(heteronym)
    print(ms)
    heteronym_dict[heteronym] = ms

We give up on identifying readings for which we have 40 or fewer examples, and then keep only the heteronyms that still have more than one reading.

In [None]:
ncut = 40
heteronym_dict_cut = {
    k: {k2: v2 for (k2, v2) in v.items() if v2 > ncut}
    for (k, v) in heteronym_dict.items()
}
heteronym_dict_cut = {k: v for (k, v) in heteronym_dict_cut.items() if len(v) > 1}
print(len(heteronym_dict_cut))
heteronym_dict_cut

In [None]:
utils.save_dict(heteronym_dict_cut, Path(config.CONFIG_DIR, "heteronyms.json"))

# Prepare augmented dataset

In [None]:
from pathlib import Path
from yomikata.config import config, logger
from yomikata import utils

input_files = [
    # Path(config.SENTENCE_DATA_DIR, "aozora.csv"),
    # Path(config.SENTENCE_DATA_DIR, "kwdlc.csv"),
    # Path(config.SENTENCE_DATA_DIR, "bccwj.csv"),
    Path(config.SENTENCE_DATA_DIR, "ndlbib.csv"),
]

utils.merge_csvs(input_files, Path(config.SENTENCE_DATA_DIR, "augmentation.csv"), n_header=1)
logger.info("✅ Merged sentence data!")

In [None]:
from pathlib import Path
import pandas as pd

df = pd.read_csv(Path(config.SENTENCE_DATA_DIR, "augmentation.csv"))
df_no_duplicates = df.drop_duplicates(subset=['sentence'], keep='first')
df_no_duplicates.to_csv(Path(config.SENTENCE_DATA_DIR, "augmentation_filtered.csv"), index=False)
logger.info("✅ Filtered out duplicate sentences!")

In [None]:
from pathlib import Path
import pandas as pd

augmentation_filtered_path = Path(config.SENTENCE_DATA_DIR, "augmentation_filtered.csv")
test_optimized_path = Path(config.SENTENCE_DATA_DIR, "test/test_optimized_strict_heteronyms.csv")
augmentation_filtered_cleaned_path = Path(config.SENTENCE_DATA_DIR, "augmentation_filtered_cleaned.csv")

df_augmented = pd.read_csv(augmentation_filtered_path)
df_test_optimized = pd.read_csv(test_optimized_path)

initial_row_count = len(df_augmented)

df_cleaned = df_augmented[~df_augmented['sentence'].isin(df_test_optimized['sentence'])]

final_row_count = len(df_cleaned)

logger.info(f"🔢 Number of rows removed: {initial_row_count - final_row_count}")

df_cleaned.to_csv(augmentation_filtered_cleaned_path, index=False)

logger.info("✅ Cleaned augmentation_filtered.csv and saved to augmentation_filtered_cleaned.csv!")

In [None]:
from pathlib import Path

from yomikata.config import config, logger
from yomikata.dataset.split import (
    check_data,
    filter_dictionary,
    filter_simple,
    optimize_furigana,
    remove_other_readings,
    split_data,
)

logger.info("Rough filtering for sentences with heteronyms")
filter_simple(
    Path(config.SENTENCE_DATA_DIR, "augmentation_filtered_cleaned.csv"),
    Path(config.SENTENCE_DATA_DIR, "augmentation_filtered_cleaned_have_heteronyms.csv"),
    config.HETERONYMS.keys(),
)

In [None]:
from pathlib import Path
from yomikata.dataset import split
from yomikata import utils
from yomikata.config import config, logger

split_dict = utils.load_dict(Path(config.BREAKDOWN_DATA_DIR, "translations.json"))
logger.info("Starting decomposition process, this may take a while...")
split.decompose_furigana(
    Path(config.SENTENCE_DATA_DIR, "augmentation_filtered_cleaned_have_heteronyms.csv"),
    Path(config.SENTENCE_DATA_DIR, "augmentation_filtered_cleaned_have_heteronyms_broken_down.csv"),
    split_dict,
)
logger.info("✅ Decomposed furigana!")

In [None]:
from pathlib import Path

from yomikata.config import config, logger
from yomikata.dataset.split import (
    check_data,
    filter_dictionary,
    filter_simple,
    optimize_furigana,
    remove_other_readings,
    split_data,
)
logger.info("Removing heteronyms with unexpected readings")
remove_other_readings(
    Path(config.SENTENCE_DATA_DIR, "augmentation_filtered_cleaned_have_heteronyms_broken_down.csv"),
    Path(config.SENTENCE_DATA_DIR, "augmentation_filtered_cleaned_have_heteronyms_broken_down_strict.csv"),
    config.HETERONYMS,
)

In [None]:
from pathlib import Path
from yomikata.config import config, logger
from yomikata import utils
import pandas as pd

df1 = pd.read_csv(Path(config.SENTENCE_DATA_DIR, "augmentation_filtered_cleaned_have_heteronyms_broken_down_strict.csv"))
df1 = df1[['sentence', 'furigana']]
temp_file = Path(config.SENTENCE_DATA_DIR, "temp_filtered.csv")
df1.to_csv(temp_file, index=False)

input_files = [
    temp_file,
    Path(config.SENTENCE_DATA_DIR, "train/train_optimized_strict_heteronyms.csv"),
]
utils.merge_csvs(input_files, Path(config.SENTENCE_DATA_DIR, "train/train_optimized_strict_heteronyms_augmented.csv"), n_header=1)

temp_file.unlink()

# Process and split data

In [None]:
from pathlib import Path

from yomikata.config import config, logger
from yomikata.dataset.split import (
    check_data,
    filter_dictionary,
    filter_simple,
    optimize_furigana,
    remove_other_readings,
    split_data,
)
from yomikata.dictionary import Dictionary

We extract from the dataset the sentences which include our heteronyms.

In [None]:
logger.info("Rough filtering for sentences with heteronyms")
filter_simple(
    Path(config.SENTENCE_DATA_DIR, "all_broken_down.csv"),
    Path(config.SENTENCE_DATA_DIR, "have_heteronyms_simple.csv"),
    config.HETERONYMS.keys(),
)

In [None]:
logger.info("Use sudachi to filter out heteronyms in known compounds")
filter_dictionary(
    Path(config.SENTENCE_DATA_DIR, "have_heteronyms_simple.csv"),
    Path(config.SENTENCE_DATA_DIR, "have_heteronyms_simple.csv"),
    config.HETERONYMS.keys(),
    Dictionary("sudachi"),
)

Finally, we remove sentences whose heteronyms appear only with readings that we are not trying to predict.
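
The filter can be pictured with a small sketch. It assumes that each heteronym appears as its own `{surface/reading}` annotation and that the heteronym table maps each surface to its expected readings; `TOY_HETERONYMS` and `has_expected_reading` are hypothetical names, and this is not the implementation of `remove_other_readings`.

```python
import re

# Hypothetical heteronym table: surface -> {reading: example count}.
TOY_HETERONYMS = {"表": {"ひょう": 100, "おもて": 80}}

# Captures (surface, reading) from each {surface/reading} annotation.
FURI = re.compile(r"\{([^{}/]+)/([^{}/]+)\}")


def has_expected_reading(furigana, heteronyms):
    """True if at least one heteronym occurs with a reading we want to predict."""
    return any(
        surface in heteronyms and reading in heteronyms[surface]
        for surface, reading in FURI.findall(furigana)
    )


rows = [
    "その{表/ひょう}を{見/み}せてください",
    "{危険/きけん}を{表/あらわ}す",
]
kept = [row for row in rows if has_expected_reading(row, TOY_HETERONYMS)]
print(kept)
```

The second row is dropped because its only heteronym, 表, carries a reading (あらわ) that is not in the table.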

In [None]:
logger.info("Removing heteronyms with unexpected readings")
remove_other_readings(
    Path(config.SENTENCE_DATA_DIR, "have_heteronyms_simple.csv"),
    Path(config.SENTENCE_DATA_DIR, "optimized_strict_heteronyms.csv"),
    config.HETERONYMS,
)

After checking that our data makes sense, we do a train/val/test split.
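
For readers unfamiliar with this kind of split, here is a minimal sketch of a shuffled three-way split. The fractions, the seed, and the helper `split_rows` are assumptions chosen for illustration; the notebook does not show how `split_data` partitions the rows.

```python
import random


def split_rows(rows, val_frac=0.15, test_frac=0.15, seed=0):
    """Shuffle rows reproducibly and cut them into train, val and test parts."""
    rows = list(rows)
    random.Random(seed).shuffle(rows)
    n_test = round(len(rows) * test_frac)
    n_val = round(len(rows) * val_frac)
    test_rows = rows[:n_test]
    val_rows = rows[n_test:n_test + n_val]
    train_rows = rows[n_test + n_val:]
    return train_rows, val_rows, test_rows


train_rows, val_rows, test_rows = split_rows(range(100))
print(len(train_rows), len(val_rows), len(test_rows))
```

Fixing the seed keeps the three parts stable between runs, which matters when the test set is later used to clean other data.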

In [None]:
test_result = check_data(
    Path(config.SENTENCE_DATA_DIR, "optimized_strict_heteronyms.csv")
)
logger.info("Performing train/val/test split")
split_data(Path(config.SENTENCE_DATA_DIR, "optimized_strict_heteronyms.csv"))

logger.info("Data splits successfully generated!")

# DBERT

We train a BERT classifier model to disambiguate the heteronyms in our data. 
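
The cells below rely on two conventions: label classes are strings of the form `surface:reading`, and tokens that are not heteronyms carry the label `-100`, which is the default `ignore_index` of PyTorch's cross-entropy loss. The sketch below shows how such a label sequence might be built; the class list, the tokenization, and `make_labels` are hypothetical and are not taken from `dBert`.

```python
IGNORE = -100  # tokens with this label do not contribute to the loss

# Hypothetical label classes in "surface:reading" form.
toy_classes = ["表:ひょう", "表:おもて", "今日:きょう", "今日:こんにち"]
class_to_index = {name: i for i, name in enumerate(toy_classes)}


def make_labels(tokens, readings):
    """Build one label per token.

    `readings` maps a token position to the annotated reading of the
    heteronym at that position; every other token gets IGNORE.
    """
    labels = []
    for pos, token in enumerate(tokens):
        key = f"{token}:{readings[pos]}" if pos in readings else None
        labels.append(class_to_index.get(key, IGNORE))
    return labels


toy_tokens = ["その", "表", "を", "見せ", "て", "ください"]
print(make_labels(toy_tokens, {1: "ひょう"}))
```

A sentence whose labels are all `-100` contains nothing to learn from, which is why the dataset is filtered on `label != -100` below.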

## Dataset Info

Before we start training, we run some simple checks on the dataset using the BERT tokenizer.

In [None]:
from pathlib import Path

from yomikata.config import config, logger
from datasets import load_dataset

dataset = load_dataset(
    "csv",
    data_files={
        "train": str(Path(config.TRAIN_DATA_DIR, "train_optimized_strict_heteronyms.csv")),
        "val": str(Path(config.VAL_DATA_DIR, "val_optimized_strict_heteronyms.csv")),
        "test": str(Path(config.TEST_DATA_DIR, "test_optimized_strict_heteronyms.csv")),
    },
)
from yomikata.dbert import dBert

reader = dBert()

dataset = dataset.map(
    reader.batch_preprocess_function, batched=True, fn_kwargs={"pad": False}
)
dataset = dataset.filter(
    lambda entry: any(label != -100 for label in entry["labels"])
)

In [None]:
import numpy as np
from collections import Counter
from tqdm import tqdm

labels = []
for key in dataset.keys():
    print(f"{key} dataset has {len(dataset[key])} members")
    have_labels = [i for i in dataset[key] if np.max(i["labels"]) != -100]
    print(f"{len(have_labels)} actually contain heteronyms")
    key_length = len(dataset[key])
    for i in tqdm(range(key_length), desc="Counting labels"):
        labels += [value for value in dataset[key][i]["labels"] if value != -100]
    print("--")

label_counter = Counter(labels)

In [None]:
from collections import defaultdict
heteronyms = defaultdict(dict)

for label in label_counter:
    label_class = reader.label_encoder.index_to_class[label]
    (surface, reading) = label_class.split(":")
    heteronyms[surface][reading] = label_counter[label]

for heteronym in reader.heteronyms:
    print("heteronym:", heteronym)
    total = 0
    for reading in heteronyms[heteronym]:
        print(reading, heteronyms[heteronym][reading])
        total += heteronyms[heteronym][reading]
    print("total:", total)
    print("------------------------------")


## Train 

To train the model in the notebook:

In [None]:
from yomikata.dbert import dBert
from datasets import load_dataset
from yomikata.config import config, logger
from pathlib import Path

reader = dBert(reinitialize=True)

dataset = load_dataset(
    "csv",
    data_files={
        "train": str(Path(config.TRAIN_DATA_DIR, "train_optimized_strict_heteronyms.csv")),
        "val": str(Path(config.VAL_DATA_DIR, "val_optimized_strict_heteronyms.csv")),
        "test": str(Path(config.TEST_DATA_DIR, "test_optimized_strict_heteronyms.csv")),
    },
)

reader.train(dataset)

Alternatively, to get MLflow integration, experiment tracking, and metrics, run the following on the command line:

```
source yomikata/venv/bin/activate

python yomikata/yomikata/main.py yomikata/config/dbert-train-args.json
```

## Use 

In [None]:
from pathlib import Path

from yomikata.config import config, logger
from yomikata.dbert import dBert
# from yomikata.main import get_artifacts_dir_from_run

# artifacts_dir = get_artifacts_dir_from_run("e392694b345e4ca19fd97f6a872ced98")
# reader = dBert(artifacts_dir)
reader = dBert()

from yomikata.dictionary import Dictionary

dictreader = Dictionary()

In [None]:
text = "知って備える新型インフルエンザ職場・家庭で今日からすべきこと"  # 知[し]って備[そな]える新型[しんがた]インフルエンザ職場[しょくば]・家庭[かてい]で今日[きょう]からすべきこと
print(dictreader.furigana(reader.furigana(text)))
print(dictreader.furigana(text))

In [None]:
text = "身体--我々自身がそれであるところの自然"  # 身体[しんたい]--我々[われわれ]自身[じしん]がそれであるところの自然[しぜん]
print(dictreader.furigana(reader.furigana(text)))
print(dictreader.furigana(text))

In [None]:
text = "気がついたものかそれとも偶然からか、狙われた団七がふと首をすくめたので、危うく鉄扇がその身体の上を通り越しながら、丁度並行して大坪流の秘術をつくしつつあった右側向うの、黒住団七ならぬ古高新兵衛の脇腹に、はッしと命中いたしました。"  # 気[き]がついたものかそれとも偶然[ぐうぜん]からか、狙[ねら]われた団七[だんしち]がふと首[くび]をすくめたので、危[あや]うく鉄扇[てっせん]がその身体[からだ]の上[うえ]を通[とお]り越[こ]しながら、丁度[ちょうど]並行[へいこう]して大坪流[おおつぼりゅう]の秘術[ひじゅつ]をつくしつつあった右側[みぎがわ]向[むこ]うの、黒住[くろずみ]団七[だんしち]ならぬ古高[ふるたか]新兵衛[しんべえ]の脇腹[わきばら]に、はッしと命中[めいちゅう]いたしました。
print(dictreader.furigana(reader.furigana(text)))
print(dictreader.furigana(text))

In [None]:
text = "Interview雲田はるこ:BLから『昭和元禄落語心中』まで人間の個性を見つめる稀代の描き手"  # Interview雲田[うんでん]はるこ:BLから『昭和[しょうわ]元禄[げんろく]落語[らくご]心中[しんじゅう]』まで人間[にんげん]の個性[こせい]を見[み]つめる稀代[きたい]の描[えが]き手[て]
reader.furigana(text)

In [None]:
text = "特集生成する身体"  # 特集[とくしゅう]生成[せいせい]する身体[しんたい]
print(dictreader.furigana(reader.furigana(text)))
print(dictreader.furigana(text))

In [None]:
text = "今日の世界情勢は"
print(dictreader.furigana(reader.furigana(text)))
print(dictreader.furigana(text))

In [None]:
text = "あの力士には金星はどれぐらいある？"
print(dictreader.furigana(reader.furigana(text)))
print(dictreader.furigana(text))

In [None]:
text = "黄色と黒の組み合わせは、危険であることを表す"
reader.furigana(text)  # 表 is a heteronym, but 表す is correctly parsed as unambiguous

In [None]:
text = "表参道に行きます"
reader.furigana(
    text
)  # 表 is a heteronym, but inside the compound 表参道 it is correctly recognized as something to look up in a dictionary

In [None]:
text = "その表を見せてください"
reader.furigana(text)  # Correct

In [None]:
text = "あの家の表は綺麗です"
reader.furigana(text)  # Correct

In [None]:
text = "建築表を見せてください"
reader.furigana(text)  # Failed?

# Code structure

In [None]:
from yomikata import utils
from yomikata.dbert import dBert

In [None]:
reader = dBert()

In [None]:
test_sentence = 'そして、{畳/たたみ}の{表/おもて}は、すでに{幾/いく}{年/ねん}{前/まえ}に{換/か}えられたのか{分/わか}らなかった'

In [None]:
%%time
disambiguated_sentence = reader.furigana(utils.remove_furigana(test_sentence))
print(disambiguated_sentence)

In [None]:
from yomikata.dictionary import Dictionary

dictreader = Dictionary()
dictreader.furigana(utils.remove_furigana(test_sentence))

In [None]:
dictreader.furigana(disambiguated_sentence)

In [None]:
dictreader.furigana(disambiguated_sentence) == dictreader.furigana(
    utils.remove_furigana(test_sentence)
)

## Test on datasets

In [None]:
from pathlib import Path

import torch
from yomikata.config import config
from yomikata.dbert import dBert
from yomikata.main import get_artifacts_dir_from_run

# artifacts_dir = get_artifacts_dir_from_run("e392694b345e4ca19fd97f6a872ced98")
# artifacts_dir = Path(
#    get_artifacts_dir_from_run("4d19dfb0d0b64b518d8e5506e3f6a726"), "checkpoint-10200"
# )

reader = dBert()

In [None]:
from datasets import load_dataset

dataset = load_dataset(
    "csv",
    data_files={
        "train": str(
            Path(config.TRAIN_DATA_DIR, "train_optimized_strict_heteronyms.csv")
        ),
        "val": str(Path(config.VAL_DATA_DIR, "val_optimized_strict_heteronyms.csv")),
        "test": str(Path(config.TEST_DATA_DIR, "test_optimized_strict_heteronyms.csv")),
    },
)

dataset = dataset.map(
    reader.batch_preprocess_function, batched=True, fn_kwargs={"pad": False}
)
dataset = dataset.filter(
    lambda entry: any(label != -100 for label in entry["labels"])
)

In [None]:
from transformers import Trainer
from yomikata.custom_bert import CustomDataCollatorForTokenClassification
import evaluate

data_collator = CustomDataCollatorForTokenClassification(
    tokenizer=reader.tokenizer, padding=True
)

accuracy_metric = evaluate.load("accuracy")
recall_metric = evaluate.load("recall")

def compute_metrics(p):
    predictions, labels = p  # predictions are already the argmax of logits
    true_predictions = [pred for prediction, label in zip(predictions, labels) for pred, lab in zip(prediction, label) if lab != -100]
    true_labels = [lab for prediction, label in zip(predictions, labels) for pred, lab in zip(prediction, label) if lab != -100]
    return {"accuracy": accuracy_metric.compute(references=true_labels, predictions=true_predictions)["accuracy"], "recall": recall_metric.compute(references=true_labels, predictions=true_predictions, average="macro", zero_division=0)["recall"]}

trainer = Trainer(
    model=reader.model,
    tokenizer=reader.tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=lambda logits, _: torch.argmax(logits, dim=-1)
)

In [None]:
%%time
import numpy as np
from yomikata.config import logger

reader.model.eval()
full_performance = {}
# for key in dataset.keys():
for key in ["test"]:
    max_evals = min(1000000, len(dataset[key]))
    # max_evals = len(dataset[key])
    logger.info(f"getting predictions for {key}")
    subset = dataset[key].shuffle().select(range(max_evals))
    prediction_output = trainer.predict(subset)
    logger.info(f"processing predictions for {key}")
    metrics = prediction_output[2]
    labels = prediction_output[1]

    logger.info("processing performance")
    performance = {
        heteronym: {
            "n": 0,
            "readings": {
                reading: {
                    "n": 0,
                    "found": {readingprime: 0 for readingprime in list(reader.heteronyms[heteronym].keys())}
                }
                for reading in list(reader.heteronyms[heteronym].keys())
            },
        }
        for heteronym in reader.heteronyms.keys()
    }

    flattened_logits = [
        logit
        for sequence_logits, sequence_labels in zip(prediction_output[0], labels)
        for (logit, l) in zip(sequence_logits, sequence_labels) if l != -100
    ]  # Already argmaxed in preprocess_logits_for_metrics, so these are predicted class indices in a 1-D list. The valid_mask processing in CustomBertForTokenClassification.forward takes care of zeroing out irrelevant logits.

    true_labels = [
        str(reader.label_encoder.index_to_class[l])
        for label in labels
        for l in label if l != -100
    ]

    for i, true_label in enumerate(true_labels):
        (true_surface, true_reading) = true_label.split(":")
        performance[true_surface]["n"] += 1
        performance[true_surface]["readings"][true_reading]["n"] += 1
        predicted_label = reader.label_encoder.index_to_class[flattened_logits[i]]
        predicted_reading = predicted_label.split(":")[1]
        performance[true_surface]["readings"][true_reading]["found"][predicted_reading] += 1

    for surface in performance:
        for true_reading in performance[surface]["readings"]:
            true_count = performance[surface]["readings"][true_reading]["n"]
            predicted_count = performance[surface]["readings"][true_reading]["found"][true_reading]
            performance[surface]["readings"][true_reading]["accuracy"] = predicted_count / true_count if true_count > 0 else "NaN"
        correct_count = sum(performance[surface]["readings"][true_reading]["found"][true_reading] for true_reading in performance[surface]["readings"])
        all_count = performance[surface]["n"]
        performance[surface]["accuracy"] = correct_count / all_count if all_count > 0 else "NaN"

    performance = {
        "metrics": metrics,
        "heteronym_performance": performance,
    }

    full_performance[key] = performance

full_performance

# Performance for dictionary 

In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
from yomikata.config import config, logger
from speach.ttlig import RubyFrag, RubyToken
from yomikata import utils
from yomikata.dictionary import Dictionary
from yomikata.dataset import breakdown
from yomikata.dataset.split import replace_furigana

reader = Dictionary("sudachi")
heteronyms = config.HETERONYMS

In [None]:
filename = Path(config.TEST_DATA_DIR, "test_optimized_strict_heteronyms.csv")
max_evals = 1000000
df = pd.read_csv(filename, header=0)
df = df.sample(frac=1)
if max_evals is not None:
    max_evals = max(max_evals, 1)
    max_evals = min(max_evals, len(df))
    df = df.head(max_evals)

df["furigana_found"] = df.apply(
    lambda x: reader.furigana(utils.standardize_text(x["sentence"])), axis=1
)

sentences = df["furigana_found"].tolist()
sentences += df["furigana"].tolist()
(split_dict, no_translation) = breakdown.sentence_list_to_breakdown_dictionary(sentences)

In [None]:
df["furigana_found"] = df["furigana_found"].apply(
    lambda s: replace_furigana(s, split_dict)
)

In [None]:
from tqdm import tqdm
performance = {
    heteronym: {
        "n": 0,
        "readings": {
            reading: {
                "n": 0,
                "found": {
                    readingprime: 0
                    for readingprime in list(heteronyms[heteronym].keys()) + ["<OTHER>"]
                },
            }
            for reading in list(heteronyms[heteronym].keys())
        },
    }
    for heteronym in heteronyms.keys()
}
failures = 0
for i, row in tqdm(df.iterrows(), total=df.shape[0], desc="processing performance"):
    matches = utils.find_all_substrings(row["sentence"], heteronyms.keys())
    furis_true = utils.get_furis(row["furigana"])
    furis_found = utils.get_furis(row["furigana_found"])
    failure = False
    for location in matches:
        surface = matches[location]
        reading_true = utils.get_reading_from_furi(location, len(surface), furis_true)
        if not reading_true:
            continue
        reading_found = utils.get_reading_from_furi(location, len(surface), furis_found)
        if not reading_found:
#            print(location, surface, row["furigana"], row["furigana_found"], row["sentence"])
            failure = True
        performance[surface]["n"] += 1
        if (reading_true in performance[surface]["readings"].keys()):
            found_reading = reading_found if reading_found in performance[surface]["readings"].keys() else "<OTHER>"
            performance[surface]["readings"][reading_true]["n"] += 1
            performance[surface]["readings"][reading_true]["found"][found_reading] += 1
    if failure:
        failures += 1
n = 0
correct = 0
for surface in performance.keys():
    for true_reading in performance[surface]["readings"].keys():
        performance[surface]["readings"][true_reading]["accuracy"] = np.round(
            performance[surface]["readings"][true_reading]["found"][true_reading]
            / np.array(performance[surface]["readings"][true_reading]["n"]),
            3,
        )

    performance[surface]["accuracy"] = np.round(
        sum(
            performance[surface]["readings"][true_reading]["found"][true_reading]
            for true_reading in performance[surface]["readings"].keys()
        )
        / np.array(performance[surface]["n"]),
        3,
    )

    correct += sum(
        performance[surface]["readings"][true_reading]["found"][true_reading]
        for true_reading in performance[surface]["readings"].keys()
    )
    n += performance[surface]["n"]

In [None]:
print(failures, len(df))

In [None]:
print("Total accuracy:", correct/n)

In [None]:
print({key: performance[key]["accuracy"] for key in performance.keys()})

In [None]:
performance

# Details of classifying based on textual embeddings

With the T5 model I am fine-tuning the whole encoder-decoder architecture to encode the embeddings and then output the correct readings for every token. This essentially assumes that every token can be ambiguous and can have any possible reading.

The Amazon paper does something simpler. It takes the BERT encodings as input and, for ambiguous tokens, trains a small classifier model to choose between 2 or 3 readings. Is this kind of thing possible for Japanese? Let's look at some tokenizations to find out.
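
To make the "choose between 2 or 3 readings" idea concrete, the sketch below restricts a softmax to the classes that belong to one heteronym. It is an illustration under stated assumptions, not the method of the paper or of Yomikata: the class list, the logits, and `pick_reading` are all hypothetical, and in practice the logits would come from a classifier head on top of the BERT encodings.

```python
import math

# Hypothetical label classes in "surface:reading" form.
TOY_CLASSES = ["金星:きんせい", "金星:きんぼし", "表:ひょう", "表:おもて"]


def pick_reading(surface, logits):
    """Softmax over only the classes of `surface`; return (reading, probability)."""
    candidates = [i for i, name in enumerate(TOY_CLASSES) if name.split(":")[0] == surface]
    top = max(logits[i] for i in candidates)
    # Subtract the maximum before exponentiating for numerical stability.
    exps = {i: math.exp(logits[i] - top) for i in candidates}
    total = sum(exps.values())
    best = max(candidates, key=lambda i: exps[i])
    return TOY_CLASSES[best].split(":")[1], exps[best] / total


# Invented logits for one token whose surface is 金星.
toy_logits = [0.2, 1.5, 3.0, -1.0]
reading, probability = pick_reading("金星", toy_logits)
print(reading, round(probability, 3))
```

Even though the third class has the largest logit overall, it belongs to a different surface and is never considered, so the choice stays between the two readings of 金星.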

## Proof of concept: Do contextual embeddings significantly differ for heteronyms?

In [None]:
word = "金星"

In [None]:
text1 = "金星は太陽系で太陽に近い方から2番目の惑星。"
text2 = "金星とは、大相撲で、平幕の力士が横綱と取組をして勝利することである。"
texts = [text1, text2]

In [None]:
from yomikata.dictionary import DictionaryReader

DicReader = DictionaryReader()
for text in texts:
    print(DicReader.tagger(text))

In [None]:
# Based on tokenizer results below 平幕 appears to be in unidic but not unidic_lite

In [None]:
from transformers import BertJapaneseTokenizer

tokenizer = BertJapaneseTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v2")

In [None]:
%%time
for text in texts:
    text_encoded = tokenizer(
        text,
        add_special_tokens=False,
    )
    input_ids = text_encoded["input_ids"]
    input_mask = text_encoded["attention_mask"]
    print(input_ids)
    print(tokenizer.convert_ids_to_tokens(input_ids))
    print(tokenizer.decode(input_ids))

In [None]:
from transformers import BertModel

model = BertModel.from_pretrained("cl-tohoku/bert-base-japanese-v2")
model.eval();

In [None]:
for text in texts:
    text_encoded = tokenizer(
        text,
        max_length=16,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
        add_special_tokens=False,
    )  # needs to be pytorch tensors
    input_ids = text_encoded["input_ids"]
    input_mask = text_encoded["attention_mask"]

    print(input_ids.shape)

    outputs = model.forward(input_ids=input_ids, attention_mask=input_mask)

    print(outputs.last_hidden_state)
    print(outputs.last_hidden_state.shape)
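
One way to turn the printed hidden states into a single number is to compare the vectors at the position of the heteronym with cosine similarity. The sketch below uses short invented vectors in place of real 768-dimensional rows of `outputs.last_hidden_state`; `cosine_similarity` is a hypothetical helper, and picking the right token position in each sentence is left out.

```python
import math


def cosine_similarity(u, v):
    """Cosine of the angle between two vectors of equal length."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Invented stand-ins for the contextual embedding of 金星 in each sentence.
planet_context = [0.9, 0.1, 0.3]
sumo_context = [0.1, 0.8, 0.4]
planet_context_scaled = [1.8, 0.2, 0.6]

print(cosine_similarity(planet_context, planet_context_scaled))
print(cosine_similarity(planet_context, sumo_context))
```

A similarity well below 1 between the two contexts would suggest that the embeddings carry enough signal for a classifier to separate the readings; values close to 1 would suggest the opposite.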

## Embedding visualization

In [None]:
from transformers import BertJapaneseTokenizer, BertModel

tokenizer = BertJapaneseTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v2")
model = BertModel.from_pretrained("cl-tohoku/bert-base-japanese-v2")
model.eval()
import numpy as np

words = np.array(list(tokenizer.vocab.keys()))
wordembs = model.embeddings.word_embeddings.weight

In [None]:
print(wordembs.shape)  # 32768 is the vocab size and 768 the embedding dimension

In [None]:
wordembs = wordembs.detach().numpy()

In [None]:
# Determine vocabulary to use for t-SNE/visualization. The indices are hard-coded based partially on inspection:
char_indices_to_use = np.arange(851, 1063, 1)
voc_indices_to_plot = np.append(char_indices_to_use, np.arange(23000, 27000, 1))
voc_indices_to_use = np.append(char_indices_to_use, np.arange(17000, 27000, 1))

In [None]:
print(len(voc_indices_to_plot))
print(len(voc_indices_to_use))

In [None]:
# list(words[voc_indices_to_use])

In [None]:
wordembs_to_use = wordembs[voc_indices_to_use]

In [None]:
from sklearn.manifold import TSNE

# Run t-SNE on the BERT vocabulary embeddings we selected:
mytsne_words = TSNE(n_components=2, early_exaggeration=12, metric="cosine", init="pca")
wordembs_to_use_tsne = mytsne_words.fit_transform(wordembs_to_use)

In [None]:
wordembs_to_use.shape

In [None]:
wordembs_to_use

In [None]:
words_to_plot = words[voc_indices_to_plot]
print(len(words_to_plot))

In [None]:
# Plot the transformed BERT vocabulary embeddings:
import japanize_matplotlib
import matplotlib.pyplot as plt

plt.rcParams["font.family"] = "VL Gothic"

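# t-SNE was fit on voc_indices_to_use, so look up the t-SNE row of each plotted vocabulary index.
row_of = {voc_idx: row for row, voc_idx in enumerate(voc_indices_to_use)}

fig = plt.figure(figsize=(100, 60))
alltexts = list()
for voc_idx, txt in zip(voc_indices_to_plot, words_to_plot):
    x, y = wordembs_to_use_tsne[row_of[voc_idx]]
    plt.scatter(x, y, s=0)
    currtext = plt.text(x, y, txt)
    alltexts.append(currtext)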


# Save the plot before adjusting.
plt.savefig("japanese-viz-bert-voc-noadj.pdf", format="pdf")
# print('now running adjust_text')
# Using autoalign often works better in my experience, but it can be very slow for this case, so it's false by default below:
# numiters = adjust_text(alltexts, autoalign=True, lim=50)
# from adjustText import adjust_text
# numiters = adjust_text(alltexts, autoalign=False, lim=50)
# print('done adjust text, num iterations: ', numiters)
# plt.savefig('japanese-viz-bert-voc-tsne10k-viz4k-adj50.pdf', format='pdf')

# plt.show()

In [None]:
# Sample code for listing the available fonts
# import matplotlib.pyplot as plt
# import matplotlib.font_manager as fm
# import numpy as np

# fonts = list(np.unique([f.name for f in fm.fontManager.ttflist]))

# fig = plt.figure(figsize=(8, 100))
# ax = fig.add_subplot(1, 1, 1)
# ax.set_ylim([-1, len(fonts)])
# ax.set_yticks(np.arange(0, len(fonts), 10))

# for i, f in enumerate(fonts):
#     ax.text(0.2, i,  '日本語強 {}'.format(f), fontdict={'family': f, 'fontsize': 14})

# plt.show()

In [None]:
from pathlib import Path

import pandas as pd
from yomikata.config import config, logger

df = pd.read_csv(Path(config.SENTENCE_DATA_DIR, "aozora.csv"))

In [None]:
# Each pair below overrides the previous one, so only the last pair is used.
# Reorder or comment out lines to visualize a different heteronym.
word = "市場"
word_classes = ["しじょう", "いちば"]
word = "礼拝"
word_classes = ["れいはい", "らいはい"]
word = "今日"
word_classes = ["きょう", "こんにち"]
word = "表"
word_classes = ["ひょう", "おもて"]
word = "仮名"
word_classes = ["かな", "かめい"]
word = "変化"
word_classes = ["へんか", "へんげ"]

In [None]:
from yomikata.heteronyms import heteronyms

print(heteronyms[heteronyms["surface"] == word])

from pathlib import Path

pronunciation_df = pd.read_csv(Path(config.PRONUNCIATION_DATA_DIR, "all.csv"))
print(pronunciation_df[pronunciation_df["surface"] == word]["pronunciations"].values)

In [None]:
df_keyword = df[df["sentence"].str.contains(word)]
df_keyword = df_keyword.reset_index(drop=True)
window_size = 128


def window_around(sentence):
    # Keep up to window_size characters before the first occurrence of word,
    # and up to window_size characters starting from it.
    idx = sentence.index(word)
    return sentence[max(0, idx - window_size) : idx + window_size]


df_keyword["sentence-shorter"] = df_keyword["sentence"].apply(window_around)
print(len(df_keyword))

In [None]:
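def reading_matcher(furigana, word, word_classes):
    """Return the first entry of word_classes found in the first bracketed reading after word, or -1."""
    try:
        shifted_furigana = furigana[furigana.index(word) :]
        # The reading is expected in square brackets after the word.
        found_reading = shifted_furigana[
            shifted_furigana.index("[") + 1 : shifted_furigana.index("]")
        ]
    except ValueError:
        # Either the word or a bracketed reading is missing from the annotation.
        print(word)
        print(furigana)
        return -1
    for reading in word_classes:
        if reading in found_reading:
            return reading
    return -1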

In [None]:
df_keyword["reading"] = df_keyword["furigana"].apply(
    lambda sentence: reading_matcher(sentence, word, word_classes)
)

In [None]:
# TODO: Improve the code for classifying words with furigana into one of the reading classes.

In [None]:
for word_class in word_classes:
    print(f"{word_class} {len(df_keyword[df_keyword['reading'] == word_class])}")
print("failures", len(df_keyword[df_keyword["reading"] == -1]))
df_keyword[df_keyword["reading"] == -1]

In [None]:
df_keyword = df_keyword[df_keyword["reading"] != -1]

In [None]:
word_id = tokenizer.encode(word, add_special_tokens=False)[0]
pad_size = 32
df_keyword["sentence-encoded"] = df_keyword["sentence-shorter"].apply(
    lambda sentence: tokenizer.encode(
        sentence,
        add_special_tokens=False,
        max_length=pad_size,
        truncation=True,
        padding="max_length",
    )
)
df_keyword["encoding-success"] = df_keyword["sentence-encoded"].apply(
    lambda encoding: word_id in encoding
)
print(len(df_keyword[~df_keyword["encoding-success"]]), "encoding failures")
df_keyword = df_keyword[df_keyword["encoding-success"]]
df_keyword = df_keyword.reset_index(drop=True)
df_keyword["keyword-index"] = df_keyword["sentence-encoded"].apply(
    lambda encoding: encoding.index(word_id)
)

In [None]:
encoding_stack = np.vstack(df_keyword["sentence-encoded"])

In [None]:
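import torch

input_ids = torch.tensor(encoding_stack)
# Mask the padding tokens so they do not influence the contextual embeddings.
with torch.no_grad():
    forward_pass = model(input_ids, attention_mask=(input_ids != tokenizer.pad_token_id).long())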

In [None]:
np.shape(forward_pass[0])

In [None]:
embs = []
for i in range(len(df_keyword)):
    embs.append(forward_pass[0][i][df_keyword.at[i, "keyword-index"]].detach().numpy())
embs = np.array(embs)

In [None]:
from sklearn.manifold import TSNE

# Run t-SNE on the contextualized embeddings:
mytsne_tokens = TSNE(
    n_components=2,
    early_exaggeration=12,
    verbose=2,
    metric="cosine",
    init="pca",
    n_iter=2000,
)
embs_tsne = mytsne_tokens.fit_transform(embs)

In [None]:
# Plot the contextual embeddings of the keyword as points, colored by reading.
import japanize_matplotlib
import matplotlib.pyplot as plt

plt.rcParams["font.family"] = "VL Gothic"

colors = ["red", "black", "blue", "green"]
classes = list(df_keyword["reading"].unique())

cs = [
    colors[classes.index(df_keyword["reading"].iloc[i])] for i in range(len(df_keyword))
]

fig = plt.figure(figsize=(6, 4))
plt.scatter(embs_tsne[:, 0], embs_tsne[:, 1], s=1, color=cs)

plt.savefig("japanese-viz-bert-ctx-points-" + word + ".pdf", format="pdf")
plt.savefig("japanese-viz-bert-ctx-points-" + word + ".png", format="png")

plt.show()

In [None]:
# Plot the keyword+context strings.
# import matplotlib.pyplot as plt
# import japanize_matplotlib

# plt.rcParams["font.family"] = "VL Gothic"

# colors = ['red', 'black']
# classes = list(df_keyword['reading'].unique())
# fig = plt.figure(figsize=(50, 30))
# alltexts = list()
# for i, txt in enumerate(df_keyword['sentence-shorter']):
#     if i % 100 == 0:
#         print(i)
#     plt.scatter(embs_tsne[i,0], embs_tsne[i,1], s=0)
#     c = colors[classes.index(df_keyword['reading'].iloc[i])]
#     currtext = plt.text(embs_tsne[i,0], embs_tsne[i,1], txt, color=c)
#     #alltexts.append(currtext)

# plt.savefig('japanese-viz-bert-ctx-text-'+word+'.pdf', format='pdf')
# # print('now running adjust_text')
# #numiters = adjust_text(alltexts, autoalign=True, lim=50)
# #numiters = adjust_text(alltexts, autoalign=False, lim=50)
# #print('done adjust text, num iterations: ', numiters)
# #plt.savefig('viz-bert-ctx-values-viz750-adj.pdf', format='pdf')

# plt.show()

## Handling out of vocab heteronyms

In [None]:
text = "その力士には金星が多くて大人気。"
# Uncomment to inspect a single word instead of the full sentence:
# text = "一時"

In [None]:
from yomikata.dictionary import DictionaryReader

DicReader = DictionaryReader()
DicReader.tagger(text)

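Here we see a problem: the ambiguous word 大人気 is split into two tokens. Does BERT use the same tokenizer? It segments words with MeCab using the unidic-lite dictionary and then applies WordPiece, so its tokens can differ from the tagger's.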
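
When two tokenizers segment the same text differently, one way to compare them is through character offsets. The following is a minimal sketch in plain Python. Both token lists are made-up illustrations, not actual tagger or tokenizer output, and the helper names `char_spans` and `overlapping` are hypothetical. It also assumes the tokens concatenate back to the original text, which does not hold for WordPiece tokens carrying a `##` prefix.

```python
def char_spans(tokens):
    """Return (start, end) character offsets, assuming the tokens concatenate to the text."""
    spans, pos = [], 0
    for tok in tokens:
        spans.append((pos, pos + len(tok)))
        pos += len(tok)
    return spans


def overlapping(tokens_a, tokens_b):
    """For each token in tokens_a, list the tokens in tokens_b sharing at least one character."""
    spans_b = char_spans(tokens_b)
    result = []
    for start, end in char_spans(tokens_a):
        result.append(
            [tok for tok, (s, e) in zip(tokens_b, spans_b) if s < end and start < e]
        )
    return result


# Hypothetical segmentations of the same string (not real tagger output).
word_level = ["大人気", "。"]
split_level = ["大", "人気", "。"]
alignment = overlapping(word_level, split_level)
print(alignment)
```

An alignment like this shows which finer-grained tokens would have to be merged or relabeled to recover the heteronym as a single unit.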

In [None]:
from transformers import BertJapaneseTokenizer

tokenizer = BertJapaneseTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v2")

In [None]:
%%time
text_encoded = tokenizer(
    text,
    add_special_tokens=False,
)
input_ids = text_encoded["input_ids"]
input_mask = text_encoded["attention_mask"]
print(input_ids)
print(tokenizer.convert_ids_to_tokens(input_ids))
tokenizer.decode(input_ids)

In [None]:
"一時" in tokenizer.vocab

In [None]:
tokenizer.vocab["一時"]

In [None]:
tokenizer.encode("一時")

In [None]:
len(tokenizer)

In [None]:
tokenizer.add_tokens(["一時"])

In [None]:
tokenizer.decode(tokenizer.encode("一時", add_special_tokens=False))

In [None]:
len(tokenizer)

Note that a token ID only selects a static, context-independent embedding. Let's look at the embeddings after the model has contextualized them.
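
To make the distinction concrete, here is a toy sketch in plain Python. It is not BERT: each token ID looks up a fixed vector, and a single unparameterized softmax-attention step then mixes in the neighbours, so the same ID ends up with different vectors in different contexts. The three-entry table `STATIC`, the IDs, and the function `contextualize` are invented for illustration.

```python
import math

# Hypothetical static embedding table: token ID -> 2-d vector.
STATIC = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [-1.0, 0.5]}


def contextualize(ids):
    """One attention step: each output is a softmax-weighted average of all input vectors."""
    vecs = [STATIC[i] for i in ids]
    out = []
    for q in vecs:
        # Dot-product similarity between this token and every token in the sequence.
        scores = [sum(a * b for a, b in zip(q, k)) for k in vecs]
        weights = [math.exp(s) for s in scores]
        total = sum(weights)
        out.append(
            [sum(w * v[d] for w, v in zip(weights, vecs)) / total for d in range(len(q))]
        )
    return out


# Token 0 appears in two different contexts.
ctx_a = contextualize([0, 1])[0]
ctx_b = contextualize([0, 2])[0]
print(STATIC[0], ctx_a, ctx_b)
```

The static vector for token 0 is identical in both cases, while `ctx_a` and `ctx_b` differ. That difference is the property the heteronym classifier relies on.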

In [None]:
from transformers import BertModel

model = BertModel.from_pretrained("cl-tohoku/bert-base-japanese-v2")
model.eval();

In [None]:
text_encoded = tokenizer(
    text,
    # max_length=4,
    # truncation=True,
    # padding="max_length",
    return_tensors="pt",
    add_special_tokens=False,
)  # needs to be pytorch tensors
input_ids = text_encoded["input_ids"]
input_mask = text_encoded["attention_mask"]

print(input_ids.shape)

outputs = model(input_ids=input_ids, attention_mask=input_mask)

print(outputs.last_hidden_state)
print(outputs.last_hidden_state.shape)

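Now let's add a word to the vocabulary. The embedding of the new token is freshly initialized rather than trained, so it carries no meaning until the model is fine-tuned.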
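
As a rough picture of what adding a token changes, the sketch below uses a toy greedy longest-match tokenizer over a made-up vocabulary. It is not the `transformers` implementation, and the names `tokenize`, `vocab`, and `embeddings` are hypothetical. It only illustrates that a new entry gets the next free ID, so the embedding table needs one more row.

```python
def tokenize(text, vocab):
    """Greedy longest-match segmentation; characters missing from the vocabulary become '[UNK]'."""
    tokens, pos = [], 0
    while pos < len(text):
        # Try the longest candidate substring first.
        for end in range(len(text), pos, -1):
            if text[pos:end] in vocab:
                tokens.append(text[pos:end])
                pos = end
                break
        else:
            tokens.append("[UNK]")
            pos += 1
    return tokens


# Hypothetical vocabulary and embedding table (one row per ID).
vocab = {"[UNK]": 0, "大": 1, "人気": 2, "。": 3}
embeddings = [[0.0, 0.0] for _ in vocab]

before = tokenize("大人気。", vocab)

# Adding a token assigns the next free ID, so the table must grow by one row.
vocab["大人気"] = len(vocab)
embeddings.append([0.0, 0.0])  # placeholder for an untrained row

after = tokenize("大人気。", vocab)
print(before, after, vocab["大人気"], len(embeddings))
```

If the table were not extended, the new ID would point past its last row, which is why the tokenizer and the model's embedding layer have to be resized together.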

In [None]:
tokenizer.add_tokens(["大人気"])
# Resize the embedding layer to match the enlarged vocabulary
model.resize_token_embeddings(len(tokenizer))

In [None]:
len(tokenizer)

In [None]:
%%time
text_encoded = tokenizer(
    text,
    add_special_tokens=False,
)
input_ids = text_encoded["input_ids"]
input_mask = text_encoded["attention_mask"]
print(input_ids)
print(tokenizer.convert_ids_to_tokens(input_ids))
tokenizer.decode(input_ids)

In [None]:
text_encoded = tokenizer(
    text,
    # max_length=4,
    # truncation=True,
    # padding="max_length",
    return_tensors="pt",
    add_special_tokens=False,
)  # needs to be pytorch tensors
input_ids = text_encoded["input_ids"]
input_mask = text_encoded["attention_mask"]

print(input_ids.shape)

outputs = model(input_ids=input_ids, attention_mask=input_mask)

print(outputs.last_hidden_state)
print(outputs.last_hidden_state.shape)