In [12]:
from pathlib import Path

import polars as pl
from sklearn.model_selection import train_test_split

from embedder_train.processing.data_splitter import StratifiedDataSplitter

# Column names in the source spreadsheets.
text_col = "Суть ситуации"  # column holding the input text
label_col = "Форма/тип ИФЛ/ТФЛ"  # label column in the main file
add_label_col = "Тип"  # label column in the additional file; renamed to label_col when loaded

# Reproducibility and split proportions (train / val / test).
seed = 0
train_frac, val_frac, test_frac = 0.8, 0.1, 0.1

# Input and output locations, relative to the current working directory.
data_dir = Path("data")
data_out_dir = data_dir.joinpath("dataset-3")
dataset_path = data_dir.joinpath("Корректные первый файл исправленный.xlsx")
add_dataset_path = data_dir.joinpath("Корректные 24.xlsx")

# Lookup table that maps label names to label ids.
labels_path = data_dir.joinpath("Тематики и id.xlsx")
labels_label_col = "Тематики на 26.06"
labels_id_col = "Id"

# Minimum text length in UTF-8 *bytes* (a Cyrillic letter takes 2 bytes, ASCII takes 1).
min_byte_len = 30

# Create the output folder up front so the later parquet writes never hit a missing directory.
data_out_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Load the main spreadsheet and discard rows whose text cell is empty (null).
df = pl.read_excel(dataset_path, engine="openpyxl").filter(
    pl.col(text_col).is_not_null()
)
df.head()

In [None]:
# Load the additional labelled file, drop rows without text, and rename its label column
# so it matches the main dataset's label column name.
# NOTE(review): sheet_id=2 should select the second worksheet (polars numbers sheets from 1);
# confirm this is the intended sheet.
add_df = (
    pl.read_excel(add_dataset_path, engine="openpyxl", sheet_id=2)
    .filter(pl.col(text_col).is_not_null())
    .rename({add_label_col: label_col})
)
add_df.head()

In [None]:
# Stack both sources, keeping only the text and label columns they have in common.
# NOTE: this rebinds `df`, so re-running this cell alone would append add_df a second time.
df = pl.concat([frame.select(text_col, label_col) for frame in (df, add_df)])
df.head()

In [None]:
# Label lookup table: the label name lives in labels_label_col and its id in labels_id_col.
labels_df = pl.read_excel(labels_path, engine="openpyxl")
labels_df.head(5)

In [None]:
# Replace each textual label with its id from the lookup table.
# NOTE(review): with how="left", labels absent from labels_df end up with a null id, and a
# label listed more than once in labels_df duplicates its rows; confirm both are acceptable.
df = df.join(labels_df, left_on=label_col, right_on=labels_label_col, how="left")
df = df.select(text_col, labels_id_col)
# Drop exact (text, id) duplicates. maintain_order=True keeps the surviving rows in input
# order. Without it polars returns them in an unspecified order, which can make the seeded
# split further down differ between runs.
df = df.unique([text_col, labels_id_col], maintain_order=True)
df.head()

In [18]:
# Keep texts that are at least min_byte_len UTF-8 bytes long and contain no "<" character
# (presumably to drop markup/HTML fragments; confirm).
df = df.filter(
    (pl.col(text_col).str.len_bytes() >= min_byte_len)
    & ~pl.col(text_col).str.contains("<", literal=True)
)

In [19]:
# Split into train/val/test using the configured fractions and seed, stratifying on the
# label id (presumably so each split keeps similar class proportions).
# NOTE(review): how the splitter treats null ids is not visible here.
splitter = StratifiedDataSplitter(
    label_col=labels_id_col,
    train_frac=train_frac,
    val_frac=val_frac,
    test_frac=test_frac,
    seed=seed,
)
splitted = splitter.split(df)
train_df, val_df, test_df = (splitted[name] for name in ("train", "val", "test"))

In [22]:
# Persist each split as a Parquet file in the output directory.
for file_name, split_df in (
    ("train_data.parquet", train_df),
    ("val_data.parquet", val_df),
    ("test_data.parquet", test_df),
):
    split_df.write_parquet(data_out_dir / file_name)

In [35]:
import plotly.express as px

# Label-id distribution in the training split. Ids are cast to strings so plotly places them
# on a categorical axis; bars are ordered from the rarest to the most frequent label.
fig = px.histogram(
    x=train_df[labels_id_col].cast(str),
    title="Training-set label distribution",
)
fig.update_xaxes(categoryorder="total ascending", title_text=labels_id_col)
fig.update_yaxes(title_text="count")
fig