In [1]:
import os

import numpy as np
import pandas as pd

In [2]:
# Load the tag dictionary (tag name -> integer index).
def load_tag_idx(filepath, encoding="utf-8-sig"):
    """Read a ``<tag>,<index>`` file into a ``{tag: int}`` dict.

    Each line is stripped and split on ``,``; the first field is the tag name
    and the second is parsed with ``int``. Any further fields are ignored, and
    a tag repeated on a later line overwrites the earlier entry. Blank lines
    are skipped.

    Args:
        filepath: Path of the tag/index file.
        encoding: Text encoding of the file. The default ``"utf-8-sig"``
            decodes plain UTF-8 and additionally drops a leading byte-order
            mark, which would otherwise be glued onto the first tag name.

    Returns:
        dict mapping each tag name (str) to its integer index.

    Raises:
        ValueError: If a non-blank line has no second field, or its index is
            not an integer.
    """
    tag_dict = {}
    with open(filepath, "r", encoding=encoding) as fin:
        for lineno, line in enumerate(fin, start=1):
            fields = line.strip().split(",")
            if fields == [""]:
                continue  # blank line, e.g. trailing empty lines at EOF
            if len(fields) < 2:
                raise ValueError(
                    f"{filepath}:{lineno}: expected '<tag>,<index>', got {line.strip()!r}"
                )
            tag_dict[fields[0]] = int(fields[1])
    return tag_dict

In [3]:
# Convert the raw data file.
def convert_data(filepath, tag_dict, savepath, encoding="utf-8-sig"):
    """Convert a raw ``tag:content`` text file into a tab-separated file.

    Each non-blank line of ``filepath`` is stripped and split at its first
    ``:``. The part before it is mapped to an integer via ``tag_dict``; the
    stripped remainder (which may itself contain ``:``) is kept as content.
    Rows are sorted by tag index with a stable sort, so lines sharing a tag
    keep their input order. The result is written to ``savepath`` without
    header or index, tab-separated, as (tag index, content).

    Args:
        filepath: Raw input file, one ``tag:content`` record per line.
        tag_dict: Mapping from tag name to integer index (see ``load_tag_idx``).
        savepath: Destination path of the tab-separated output.
        encoding: Text encoding of ``filepath``; ``"utf-8-sig"`` also strips a
            leading byte-order mark that would corrupt the first tag.

    Raises:
        ValueError: If a non-blank line contains no ``:`` separator.
        KeyError: If a line's tag is not present in ``tag_dict``.
    """
    rows = []
    with open(filepath, "r", encoding=encoding) as fin:
        # Stream the file instead of readlines() so large inputs are not
        # loaded into memory twice.
        for lineno, line in enumerate(fin, start=1):
            record = line.strip()
            if not record:
                continue  # tolerate blank lines
            tag, sep, ctx = record.partition(":")
            if not sep:
                raise ValueError(
                    f"{filepath}:{lineno}: missing ':' separator in {record!r}"
                )
            try:
                tag_idx = tag_dict[tag]
            except KeyError:
                raise KeyError(f"{filepath}:{lineno}: unknown tag {tag!r}") from None
            rows.append([tag_idx, ctx.strip()])

    rst = pd.DataFrame(rows, columns=["tag", "ctx"]).sort_values(
        by="tag", kind="mergesort"  # stable: keeps input order within a tag
    )
    rst.to_csv(savepath, header=False, index=False, sep="\t")

In [4]:
# Split the dataset.
def split_data(filepath, savedir, random_state=None):
    """Split a converted data file into per-tag stratified train/valid/test files.

    Rows are grouped by tag. Within each tag, ``frac=0.6`` of the rows are
    sampled for train, and the remainder is split in half (``frac=0.5``)
    between valid and test, i.e. roughly 60/20/20 overall (``sample`` rounds
    the counts). Each split is shuffled and written to ``train.csv``,
    ``valid.csv`` and ``test.csv`` inside ``savedir`` as headerless,
    index-free, tab-separated files with the same two columns
    (tag, content) in the same order as the input.

    Args:
        filepath: Headerless tab-separated (tag, content) file, such as the
            output of ``convert_data``.
        savedir: Directory the three split files are written into. Paths are
            built with ``os.path.join``, so a trailing separator is optional.
        random_state: Seed (int) or ``numpy.random.Generator`` driving every
            sampling step. The default ``None`` draws fresh OS entropy, so the
            split is then not reproducible.
    """
    # Keep content as plain strings and disable NA detection so texts such as
    # "NA", "null" or "1.50" round-trip unchanged instead of becoming empty
    # cells or being re-formatted as floats.
    org_data = pd.read_csv(
        filepath,
        header=None,
        sep="\t",
        names=["tags", "content"],
        dtype={"content": str},
        keep_default_na=False,
    )

    # One generator for all sampling, so a single seed reproduces the split.
    rng = np.random.default_rng(random_state)

    train_parts, valid_parts, test_parts = [], [], []
    for _, group in org_data.groupby("tags", sort=False):
        train = group.sample(frac=0.6, replace=False, random_state=rng)
        rest = group.drop(train.index)  # read_csv gives a unique RangeIndex
        valid = rest.sample(frac=0.5, replace=False, random_state=rng)
        train_parts.append(train)
        valid_parts.append(valid)
        test_parts.append(rest.drop(valid.index))

    outputs = (
        ("train.csv", train_parts),
        ("valid.csv", valid_parts),
        ("test.csv", test_parts),
    )
    for filename, parts in outputs:
        # Concatenate once (DataFrame.append is gone in pandas 2.0); fall back
        # to an empty frame with the input's columns when there is no data.
        split = pd.concat(parts) if parts else org_data.iloc[:0]
        # Shuffle so rows of one tag are not contiguous in the output.
        split = split.sample(frac=1, random_state=rng)
        split.to_csv(
            os.path.join(savedir, filename), header=False, index=False, sep="\t"
        )

In [6]:
# Load the tag-name -> index mapping consumed by convert_data below.
tag_idx_file = "static/data/tag-idx.csv"
tag_dict = load_tag_idx(filepath=tag_idx_file)

In [7]:
# Convert the raw "tag:content" posts file into a tab-separated (tag index, content) file.
org_file = "static/data/huati_filter_final_posts_no_sge.txt"
savepath = "static/data/cvt_org_data.csv"
convert_data(filepath=org_file, tag_dict=tag_dict, savepath=savepath)

In [8]:
# Split the converted file into train / valid / test files under static/data/.
org_data_file = "static/data/cvt_org_data.csv"
savedir = "static/data/"
split_data(filepath=org_data_file, savedir=savedir)