In [4]:
import os
import pandas as pd

In [7]:
# Directory holding bbc-news-data.csv and the generated split files.
# Override with the BBC_DATA_LOC environment variable instead of editing this absolute path.
DATA_LOC = os.environ.get("BBC_DATA_LOC", '/root/xlcoder/MiniMind2-Small/dataset')

In [6]:
# Split configuration:
#   test_ratio  - fraction of each category's rows held out for the test set
#   random_seed - seed passed to every shuffle so the split is reproducible
test_ratio, random_seed = 0.2, 42

In [9]:
# Stratified train/test split: within each category, shuffle the rows and hold out
# the first int(len * test_ratio) for test; the rest go to train.
df = pd.read_csv(os.path.join(DATA_LOC, "bbc-news-data.csv"), sep='\t')

# groupby() silently drops rows whose key is NaN, which would make them vanish from
# both splits; fail loudly instead of losing data.
assert df["category"].notna().all(), "found rows with a missing category"

train_parts = []
test_parts = []

for _, group in df.groupby("category"):
    # Shuffle this category's rows reproducibly.
    group = group.sample(frac=1, random_state=random_seed)

    # Number of test rows for this category (rounded down).
    n_test = int(len(group) * test_ratio)

    # First n_test shuffled rows -> test, remainder -> train.
    test_parts.append(group.iloc[:n_test])
    train_parts.append(group.iloc[n_test:])

train_all = pd.concat(train_parts)
test_all = pd.concat(test_parts)

# The split must be a partition of the original rows: no overlap, nothing lost.
assert train_all.index.intersection(test_all.index).empty, "train/test overlap"
assert len(train_all) + len(test_all) == len(df), "rows lost during split"

# Shuffle across categories so rows are not grouped by label, then renumber 0..n-1.
train_df = train_all.sample(frac=1, random_state=random_seed).reset_index(drop=True)
test_df = test_all.sample(frac=1, random_state=random_seed).reset_index(drop=True)

# Save the splits (comma-separated, no index column).
train_df.to_csv(os.path.join(DATA_LOC, "bbc_train.csv"), index=False)
test_df.to_csv(os.path.join(DATA_LOC, "bbc_test.csv"), index=False)

print("训练集数量:", len(train_df))
print("测试集数量:", len(test_df))
print(train_df["category"].value_counts())
print(test_df["category"].value_counts())


训练集数量: 1781
测试集数量: 444
category
sport            409
business         408
politics         334
tech             321
entertainment    309
Name: count, dtype: int64
category
sport            102
business         102
politics          83
tech              80
entertainment     77
Name: count, dtype: int64


In [15]:
# Reload the held-out split from disk (bbc_test.csv, written by the split cell).
test_csv_path = os.path.join(DATA_LOC, "bbc_test.csv")
df = pd.read_csv(test_csv_path)

In [16]:
# Map each raw category name to its class-token string.
replace_dict = {
    "business": "<CLS_B>",
    "entertainment": "<CLS_E>",
    "politics": "<CLS_P>",
    "sport": "<CLS_S>",
    "tech": "<CLS_T>",
}
# Series.replace leaves values missing from the dict unchanged, so an unexpected
# category would silently survive into the "standardized" file. Check that every value
# is now a known token. This check passes again on re-run, because replace() is a no-op
# on values that are already tokens.
df['category'] = df['category'].replace(replace_dict)
unmapped = set(df['category']) - set(replace_dict.values())
assert not unmapped, f"categories without a class token: {unmapped}"
df

Unnamed: 0,category,filename,title,content
0,<CLS_S>,132.txt,Republic to face China and Italy,The Republic of Ireland have arranged friendl...
1,<CLS_E>,166.txt,TV station refuses adoption show,A TV station in the US has refused to show a ...
2,<CLS_B>,019.txt,India widens access to telecoms,India has raised the limit for foreign direct...
3,<CLS_T>,282.txt,Dublin hi-tech labs to shut down,"Dublin's hi-tech research laboratory, Media L..."
4,<CLS_B>,281.txt,Axa Sun Life cuts bonus payments,Life insurer Axa Sun Life has lowered annual ...
...,...,...,...,...
439,<CLS_E>,151.txt,Eminem secret gig venue revealed,Rapper Eminem is to play an intimate gig in L...
440,<CLS_S>,154.txt,Reyes tricked into Real admission,Jose Antonio Reyes has added to speculation l...
441,<CLS_S>,406.txt,Fuming Robinson blasts officials,"England coach Andy Robinson said he was ""livi..."
442,<CLS_T>,141.txt,US top of supercomputing charts,The US has pushed Japan off the top of the su...


In [17]:
# Write the relabelled test split alongside the other dataset files (no index column).
# NOTE(review): only the test split is converted to <CLS_*> labels in this notebook;
# bbc_train.csv keeps the raw category names. Confirm this is intended.
std_test_path = os.path.join(DATA_LOC, "bbc_test_std.csv")
df.to_csv(std_test_path, index=False)