# 改行・読点挿入モデル データセット作成


## Setup


In [None]:
import sqlite3
import numpy as np
import pandas as pd
import json


In [None]:
# Path of the SQLite database file that is opened with sqlite3.connect.
DB_FILE_NAME = "sidb.db"
# Name of the table whose rows are all selected into the DataFrame.
TABLE_NAME = "docs"


In [None]:
# Load the entire table into a DataFrame. The try/finally guarantees the
# connection is closed even when the query raises (for example when the
# table does not exist), instead of leaking the open handle.
# NOTE: TABLE_NAME is interpolated into the SQL text, so it must remain a
# trusted constant; identifiers cannot be bound as query parameters.
conn = sqlite3.connect(DB_FILE_NAME)
try:
    df = pd.read_sql(f"SELECT * FROM {TABLE_NAME}", conn)
finally:
    conn.close()


## Show data


In [None]:
# DataFrame.info() prints its summary itself and returns None, so it is
# called directly; wrapping it in print() would emit a stray "None" line.
df.info()

# Preview up to three random rows. Capping n keeps sample() from raising
# when the table holds fewer than three rows.
df_check = df.sample(n=min(3, len(df)))
MAX_ONE_LINE_LENGTH = 100
for index in df_check.index:
    print("\n\n")
    for col in df_check.columns:
        # Stringify the cell once, then truncate long values so every
        # column stays readable on a single line.
        value = str(df_check.loc[index, col])
        if len(value) > MAX_ONE_LINE_LENGTH:
            value = value[:MAX_ONE_LINE_LENGTH] + "..."
        print(col, value)


## Check data to be created

```py
['id', 'content', 'meta_info', 'sentence', 'clause', 'chunk', 'token', 'lf', 'lfp_lf', 'lfp_p', 'p']
```

- sentence = 文
- clause = 節
- chunk = 文節
- token = 形態素
- lf = 改行 (line feed)
- p = 読点 (punctuation; the Japanese comma 「、」)
- lfp_lf = 改行 (改行と読点の組み合わせ)
- lfp_p = 読点 (改行と読点の組み合わせ)


In [None]:
# for index in [1]:  # df.index:
#     content: str = df.loc[index, "content"]
#     sentence_data: list[dict] = json.loads(df.loc[index, "sentence"])
#     clause_data: list[dict] = json.loads(df.loc[index, "clause"])
#     chunk_data: list[dict] = json.loads(df.loc[index, "chunk"])
#     token_data: list[dict] = json.loads(df.loc[index, "token"])
#     lf_data: list[dict] = json.loads(df.loc[index, "lf"])
#     lfp_lf_data: list[dict] = json.loads(df.loc[index, "lfp_lf"])
#     lfp_p_data: list[dict] = json.loads(df.loc[index, "lfp_p"])
#     # p_data: list[dict]  =  df.loc[index, "p"] # データなし

#     lfp_lf_data = [x["end"] for x in lfp_lf_data]
#     lfp_p_data = [x["end"] for x in lfp_p_data]
#     sentence_data = [x["end"] for x in sentence_data]
#     for data in chunk_data:
#         print(content[data["begin"] : data["end"]], end="")
#         is_end_of_sentence = data["end"] in sentence_data
#         is_lf = data["end"] in lfp_lf_data
#         is_p = data["end"] in lfp_p_data and not is_end_of_sentence
#         if is_p:
#             print("、", end="")
#         if is_end_of_sentence:
#             print("。", end="")
#         if is_lf:
#             print()


## Create Dataset


In [None]:
# Marker inserted between the two chunks of every input pair; the labels of
# a row describe the boundary at this position.
SPECIAL_TOKEN = "[ANS]"
# Column order of every dataset frame produced by create_dataset.
COLUMNS = ["input", "is_line_feed", "comma_period"]


def create_dataset(df: pd.DataFrame, index, dataset: pd.DataFrame) -> pd.DataFrame:
    """Append the chunk-boundary examples of one document to ``dataset``.

    One example is produced for every pair of adjacent chunks of the
    document stored in row ``index`` of ``df``:

    * ``input``: text of the chunk, ``SPECIAL_TOKEN``, text of the next chunk.
    * ``is_line_feed``: 1 if the chunk's end offset is listed in ``lfp_lf``,
      otherwise 0.
    * ``comma_period``: 2 if the chunk's end offset ends a sentence, 1 if it
      is listed in ``lfp_p`` (and does not end a sentence), otherwise 0.

    Args:
        df: Frame with a ``content`` string column and JSON-encoded
            ``chunk``, ``sentence``, ``lfp_lf`` and ``lfp_p`` columns. Each
            JSON value is a list of objects; ``end`` is read from all of
            them and ``begin`` additionally from the chunks.
        index: Row label of the document inside ``df``.
        dataset: Examples accumulated so far, with columns ``COLUMNS``.

    Returns:
        A frame with columns ``COLUMNS`` holding the previous rows followed
        by the new ones, re-indexed from 0. When the document yields no
        example (fewer than two chunks) ``dataset`` is returned unchanged.

    Raises:
        KeyError: If ``index`` or a required column/JSON key is missing.
        json.JSONDecodeError: If one of the JSON columns is malformed.
    """
    content: str = df.loc[index, "content"]
    chunk_data: list[dict] = json.loads(df.loc[index, "chunk"])
    sentence_data: list[dict] = json.loads(df.loc[index, "sentence"])
    lfp_lf_data: list[dict] = json.loads(df.loc[index, "lfp_lf"])
    lfp_p_data: list[dict] = json.loads(df.loc[index, "lfp_p"])

    # End offsets as sets so each boundary test is a constant-time lookup.
    line_feed_ends = {x["end"] for x in lfp_lf_data}
    comma_ends = {x["end"] for x in lfp_p_data}
    sentence_ends = {x["end"] for x in sentence_data}

    rows = []
    # Pair every chunk with its successor; the last chunk has no successor
    # and therefore only appears as the right-hand side of the final pair.
    for current, following in zip(chunk_data, chunk_data[1:]):
        boundary = current["end"]
        # 0 = nothing inserted, 1 = comma (読点), 2 = period (句点).
        # A sentence end takes precedence over a comma at the same offset.
        if boundary in sentence_ends:
            comma_period = 2
        elif boundary in comma_ends:
            comma_period = 1
        else:
            comma_period = 0
        rows.append(
            [
                content[current["begin"] : current["end"]]
                + SPECIAL_TOKEN
                + content[following["begin"] : following["end"]],
                int(boundary in line_feed_ends),
                comma_period,
            ]
        )

    # A document with fewer than two chunks contributes nothing. Stacking an
    # empty array used to raise a shape-mismatch ValueError in this case.
    if not rows:
        return dataset

    # Build the rows as a DataFrame directly: going through np.array() would
    # coerce the integer labels to strings because of the mixed str/int rows.
    new_rows = pd.DataFrame(rows, columns=COLUMNS)
    if dataset.empty:
        # Nothing to prepend; also avoids concatenating an empty frame,
        # which would interfere with dtype inference of the label columns.
        return new_rows
    return pd.concat([dataset, new_rows], ignore_index=True)


# The final document is held out for testing; every document before it
# goes into the training split.
last_row = len(df) - 1

# Training split: accumulate the examples of each document in turn.
train_dataset = pd.DataFrame(columns=COLUMNS)
for index in range(last_row):
    train_dataset = create_dataset(df, index, train_dataset)
train_dataset.to_csv("train_dataset.csv")

# Test split: examples of the held-out final document only.
index = last_row
test_dataset = create_dataset(df, index, pd.DataFrame(columns=COLUMNS))
test_dataset.to_csv("test_dataset.csv")


## Confirm Created Dataset


In [None]:
# Rebuild the test document from the saved CSV as a visual sanity check of
# the generated labels.
created_dataset = pd.read_csv("test_dataset.csv")

trailing_chunk = None
for text_pair, line_feed, punctuation in zip(
    created_dataset["input"],
    created_dataset["is_line_feed"],
    created_dataset["comma_period"],
):
    # Each input is "<chunk>[ANS]<next chunk>". Only the left chunk is
    # printed here because the right one is the left chunk of the next row.
    text, trailing_chunk = text_pair.split(SPECIAL_TOKEN)
    print(text, end="")
    # comma_period: 1 = comma, 2 = period, anything else = nothing inserted.
    if punctuation == 1:
        print("、", end="")
    elif punctuation == 2:
        print("。", end="")
    if line_feed:
        print()

# The right-hand chunk of the last row has no row of its own; print it so
# the final chunk of the document is not dropped from the reconstruction.
if trailing_chunk is not None:
    print(trailing_chunk)
