# 拆分数据集

针对What的List模式进行数据集拆分，支持交叉验证和随机划分两种方式，并通过test_size参数指定测试集所占的比例。

```python
def split_dataset(X_data: pd.DataFrame, y_data: pd.DataFrame = None, test_size=0.2, n_trails=10,
                  cv: bool = False, shuffle: bool = False, random_state=None, save_dir=None):
    """
    Split a dataset.

    Args:
        X_data: Training data.
        y_data: Supervision (label) data.
        test_size: Proportion of the data assigned to the test set.
        n_trails: Number of attempts made when searching for the best dataset split.
        cv: Whether to use cross-validation. Defaults to False; when True,
            n_trails is used as the number of cross-validation folds (n_fold).
        shuffle: Whether to shuffle the data randomly.
        random_state: Random seed.
        save_dir: Path where the split information is saved.

    Returns: The list of data produced by the split.

    """
 ```

# 单中心数据

所有的数据，按照比例划分

In [1]:
import os
import pandas as pd

from onekey_algo import OnekeyDS as okds
from onekey_algo.custom.components.comp2 import split_dataset4sol
from onekey_algo import get_param_in_cwd

# label_file: path to the label file of your single-center cohort.
label_file = r'C:/Users/onekey/Desktop/demo/20250627/Demodata/label.csv'

data = pd.read_csv(label_file)
# The label column is chosen by the 'event_col' parameter looked up via get_param_in_cwd.
event_labels = data[get_param_in_cwd('event_col')]
rt = split_dataset4sol(
    data,
    event_labels,
    test_size=0.3,
    n_trails=5,
    cv=False,
    save_dir='.',
)
# Unpack the two parts of the first returned split.
x1, x2 = rt[0]

FileNotFoundError: [Errno 2] No such file or directory: 'C:/Users/onekey/Desktop/demo/20250627/Demodata/label.csv'

# 多中心数据

仅对训练集（group为train的数据）按比例划分出验证集，其余数据保持原有分组不变。

In [None]:
import os
import re
import shutil
import pandas as pd
from onekey_algo.custom.components.comp2 import split_dataset4sol
from onekey_algo import get_param_in_cwd
from onekey_algo.custom.utils import print_join_info

label_file = r'F:\wlx_OK_PDAC\OS_prediction-V3\data/label.csv'
data = pd.read_csv(label_file)
# Only rows tagged 'train' are re-split; every other row is carried through unchanged.
train_data = data[data['group'] == 'train']
test_data = data[data['group'] != 'train']

# Column names are loop-invariant, so look them up once instead of on every export.
event_col = get_param_in_cwd('event_col')
duration_col = get_param_in_cwd('duration_col')
export_cols = ['ID', event_col, duration_col]

rt = split_dataset4sol(train_data, train_data[event_col], cv=False, n_trails=40, test_size=0.3,
                       save_dir='.', shuffle=True)

# Every file below is written under split_info/; make sure the directory exists.
os.makedirs('split_info', exist_ok=True)

for idx, (train, val) in enumerate(rt):
    # Relabel the held-out part of the training rows as the validation group.
    val['group'] = 'val'
    rnd = pd.concat([train, val, test_data], axis=0)
    display(rnd['group'].value_counts())
    rnd.to_csv(f'split_info/label-RND-{idx}.csv', index=False)

    # The label CSV above keeps the original IDs; the list files use '.png' names.
    rnd['ID'] = rnd['ID'].map(lambda x: x.replace('.gz', '.png'))
    # Explicit copies: both frames get their 'ID' column reassigned below, and
    # assigning into a selection of `rnd` would raise SettingWithCopyWarning.
    tr = rnd[rnd['group'] == 'train'].copy()
    ts = rnd[rnd['group'] != 'train'].copy()
    tr[export_cols].to_csv(f'split_info/train-RND-{idx}.txt', sep='\t', index=False, header=False)
    # The "val" list holds every row whose group is not 'train'.
    ts[export_cols].to_csv(f'split_info/val-RND-{idx}.txt', sep='\t', index=False, header=False)

    # Same lists again with '.npy' names for the "25d" files.
    tr['ID'] = tr['ID'].map(lambda x: x.replace('.png', '.npy'))
    tr[export_cols].to_csv(f'split_info/train25d-RND-{idx}.txt', sep='\t', index=False, header=False)
    ts['ID'] = ts['ID'].map(lambda x: x.replace('.png', '.npy'))
    ts[export_cols].to_csv(f'split_info/val25d-RND-{idx}.txt', sep='\t', index=False, header=False)

In [None]:
# Display `rnd` as left by the last iteration of the split loop above.
rnd