### 创建测试集

In [2]:
import numpy as np
import pandas as pd
import hashlib
import os

In [3]:
# Directory (relative path) that holds the housing dataset files.
housing_path = "datasets/housing"
# Full path of the raw CSV inside that directory.
csv_path = os.path.join(housing_path, "housing.csv")
# Load the whole CSV into a DataFrame; the cells below all work on housing_data.
housing_data = pd.read_csv(csv_path)

In [7]:
def test_set_check(identifier, test_ratio, hash):
    """
    算出每个实例ID的哈希值,只保留其最后一个字节,如果该值小于等于51(约为256的20%),就将其放入测试集
    """
    return hash(np.int64(identifier)).digest()[-1] < 256 * test_ratio

In [8]:
def split_train_test_by_id(data, test_ratio, id_column, hash=hashlib.md5):
    """Split ``data`` into train and test parts using a per-row identifier.

    Each value in ``data[id_column]`` is passed to ``test_set_check``; rows
    for which it returns True form the test set, all others the train set.

    Args:
        data: DataFrame to split.
        test_ratio: Ratio forwarded to ``test_set_check``.
        id_column: Name of the column holding the row identifiers.
        hash: Hash constructor forwarded to ``test_set_check``.

    Returns:
        A ``(train_set, test_set)`` tuple of DataFrames.
    """
    test_mask = data[id_column].apply(
        lambda row_id: test_set_check(row_id, test_ratio, hash)
    )
    train_part = data.loc[~test_mask]
    test_part = data.loc[test_mask]
    return train_part, test_part

#### 构建索引

In [23]:
# Build a row identifier from the coordinate columns and split on it.
# housing_with_id is created here as a copy of housing_data: it was not
# defined by any earlier cell, and copying keeps the helper "id" column
# out of housing_data itself.
# NOTE(review): "经度" is longitude; the column name "维度" presumably denotes
# latitude (normally written "纬度") - confirm it matches the CSV header.
housing_with_id = housing_data.copy()
housing_with_id["id"] = housing_with_id["经度"] * 1000 + housing_with_id["维度"]
train_set, test_set = split_train_test_by_id(housing_with_id, 0.2, "id")

In [24]:
# Write the id-based train/test split next to the source data,
# overwriting any existing files and omitting the DataFrame index.
for _subset, _file_name in (
    (train_set, "housing_train.csv"),
    (test_set, "housing_test.csv"),
):
    _subset.to_csv(
        path_or_buf=os.path.join(housing_path, _file_name),
        mode='w',
        index=False,
    )

### 分层采样

In [22]:
# Derive an income category from the median-income column: divide by 1.5
# and round up so the number of categories stays small.
income_cat = np.ceil(housing_data["收入中位数"] / 1.5)
# Keep categories below 5 and merge everything else into category 5.0.
# The result is assigned back explicitly: calling .where(..., inplace=True)
# on a column selection acts on a temporary object under pandas
# Copy-on-Write and would leave housing_data unchanged.
housing_data["income_cat"] = income_cat.where(income_cat < 5, 5.0)

In [25]:
from sklearn.model_selection import StratifiedShuffleSplit
# One stratified shuffle split that holds out 20% of the rows as the test set;
# the fixed random_state keeps the split the same from run to run.
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)


In [34]:
# Stratified sampling driven by the income category.
# split.split() yields positional row indices, so rows are selected with
# .iloc; label-based .loc only gives the same rows while housing_data still
# has its default 0..n-1 index.
for train_index, test_index in split.split(housing_data, housing_data["income_cat"]):
    start_train_set = housing_data.iloc[train_index]
    start_test_set = housing_data.iloc[test_index]

In [28]:
# Share of rows in each income category: per-category counts divided by the
# total number of rows in housing_data.
housing_data["income_cat"].value_counts() / len(housing_data)

3.0    0.350581
2.0    0.318847
4.0    0.176308
5.0    0.114438
1.0    0.039826
Name: income_cat, dtype: float64

In [35]:
# Remove the helper income_cat column so both subsets go back to the
# original set of columns.
for data_set in (start_train_set, start_test_set):
    data_set.drop(columns=["income_cat"], inplace=True)

In [37]:
# Save the subsets produced by stratified sampling, overwriting any existing
# files and omitting the DataFrame index.
for _subset, _file_name in (
    (start_train_set, "housing_train_stratified_sampling.csv"),
    (start_test_set, "housing_test_stratified_sampling.csv"),
):
    _subset.to_csv(
        path_or_buf=os.path.join(housing_path, _file_name),
        mode='w',
        index=False,
    )