In [87]:
import pandas as pd
from tensorflow import keras
import os
import tensorflow as tf

In [161]:
'''
    返回census-income数据集
    @param train_data:训练集特征
    @param train_label:训练集标签
    @param test_data:测试集特征
    @param test_label:测试集标签
    @param validation_data:验证集特征
    @param validation_label:验证集标签
'''
def data_processing():
    # 数据集列表名
    column_names = ['age', 'class_worker', 'det_ind_code', 'det_occ_code', 'education', 'wage_per_hour', 'hs_college',
                    'marital_stat', 'major_ind_code', 'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member',
                    'unemp_reason', 'full_or_part_emp', 'capital_gains', 'capital_losses', 'stock_dividends',
                    'tax_filer_stat', 'region_prev_res', 'state_prev_res', 'det_hh_fam_stat', 'det_hh_summ',
                    'instance_weight', 'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same', 'mig_prev_sunbelt',
                    'num_emp', 'fam_under_18', 'country_father', 'country_mother', 'country_self', 'citizenship',
                    'own_or_self', 'vet_question', 'vet_benefits', 'weeks_worked', 'year', 'income_50k']
    #读取训练数据
    train_df = pd.read_csv(
        '/Users/lizhen/Code/data_set/census-income/census-income.data.gz',
        delimiter=',',
        header=None,
        index_col=None,
        names=column_names
    )
    #读取测试数据
    other_df = pd.read_csv(
        '/Users/lizhen/Code/data_set/census-income/census-income.test.gz',
        delimiter=',',
        header=None,
        index_col=None,
        names=column_names
    )
    # 特征名字
    categorical_columns = ['class_worker', 'det_ind_code', 'det_occ_code', 'education', 'hs_college', 'major_ind_code',
                           'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member', 'unemp_reason',
                           'full_or_part_emp', 'tax_filer_stat', 'region_prev_res', 'state_prev_res', 'det_hh_fam_stat',
                           'det_hh_summ', 'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same', 'mig_prev_sunbelt',
                           'fam_under_18', 'country_father', 'country_mother', 'country_self', 'citizenship',
                           'vet_question']
    # First group of tasks according to the paper
    label_columns = ['income_50k', 'marital_stat']

    # One-hot encoding categorical columns
    categorical_columns = ['class_worker', 'det_ind_code', 'det_occ_code', 'education', 'hs_college', 'major_ind_code',
                           'major_occ_code', 'race', 'hisp_origin', 'sex', 'union_member', 'unemp_reason',
                           'full_or_part_emp', 'tax_filer_stat', 'region_prev_res', 'state_prev_res', 'det_hh_fam_stat',
                           'det_hh_summ', 'mig_chg_msa', 'mig_chg_reg', 'mig_move_reg', 'mig_same', 'mig_prev_sunbelt',
                           'fam_under_18', 'country_father', 'country_mother', 'country_self', 'citizenship',
                           'vet_question']
    train_raw_labels = train_df[label_columns]
    other_raw_labels = other_df[label_columns]
    transformed_train = pd.get_dummies(train_df.drop(label_columns, axis=1), columns=categorical_columns)
    transformed_other = pd.get_dummies(other_df.drop(label_columns, axis=1), columns=categorical_columns)
    #print(transformed_train.columns.values)
    #print(transformed_other.shape)
    #print(transformed_train.shape)
    
    transformed_other['det_hh_fam_stat_ Grandchild <18 ever marr not in subfamily'] = 0
    
    #获得标签量，并根据要求转换为one-hot向量
    train_income = keras.utils.to_categorical((train_raw_labels.income_50k == ' 50000+.').astype(int), num_classes=2)
    train_marital = keras.utils.to_categorical((train_raw_labels.marital_stat == ' Never married').astype(int), num_classes=2)
    other_income = keras.utils.to_categorical((other_raw_labels.income_50k == ' 50000+.').astype(int), num_classes=2)
    other_marital = keras.utils.to_categorical((other_raw_labels.marital_stat == ' Never married').astype(int), num_classes=2)
    
    # 转换为字典
    dict_outputs = {
        'income': train_income.shape[1],
        'marital': train_marital.shape[1]
    }
    dict_train_labels = {
        'income': train_income,
        'marital': train_marital
    }
    dict_other_labels = {
        'income': other_income,
        'marital': other_marital
    }
    output_info = [(dict_outputs[key], key) for key in sorted(dict_outputs.keys())]
    
    # 将测试集划分 为测试集和验证集 1:1的比例
    validation_indices = transformed_other.sample(frac=0.5, replace=False, random_state=1).index
    test_indices = list(set(transformed_other.index) - set(validation_indices))
    validation_data = transformed_other.iloc[validation_indices]
    validation_label = [dict_other_labels[key][validation_indices] for key in sorted(dict_other_labels.keys())]
    test_data = transformed_other.iloc[test_indices]
    test_label = [dict_other_labels[key][test_indices] for key in sorted(dict_other_labels.keys())]
    train_data = transformed_train
    train_label = [dict_train_labels[key] for key in sorted(dict_train_labels.keys())]

    return train_data, train_label, validation_data, validation_label, test_data, test_label

In [162]:
'''
    生成包含数据及标签的矩阵,因为在这里label是一个包含两个array的矩阵，因此需要有部分操作
    @param data 数据集合
    @param label 数据标签
    
    @return data 返回数据与数据集合一起的标签
'''
def generate_data(data, label):
    for i in range(0, len(label)):
        data = np.c_[data, label[i]]
    return data

In [163]:
train_data, train_label, val_data, val_label, test_data, test_label = data_processing()

In [164]:
train_data_np = generate_data(train_data, train_label)
test_data_np = generate_data(test_data, test_label)
val_data_np = generate_data(val_data, val_label)

In [166]:
output_dir = '/Users/lizhen/Code/data_set/census-income/generate_csv'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

In [167]:
'''
    @param output_dir 输出目录
    @param data 输出数据
    @param name_prefix 文件名前缀
    @param header
    @param n_parts 划分文件数大小
    @return filenames 生成文件名list
'''
def save_to_csv(output_dir, data, name_prefix, header=None, n_parts=10):
    path_format = os.path.join(output_dir, '{}_{:02d}.csv')
    filenames = []
    
    for file_idx, row_indices in enumerate(np.array_split(np.arange(len(data)), n_parts)):
        part_csv = path_format.format(name_prefix, file_idx)
        filenames.append(part_csv)
        with open(part_csv, 'wt', encoding='utf-8') as f:
            if header is not None:
                f.write(header + '\n')
            for row_index in row_indices:
                f.write(','.join([repr(col) for col in data[row_index]]))
                f.write('\n')
    return filenames

In [168]:
save_to_csv(output_dir, train_data_np, 'train_data', None, 10)

['/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_00.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_01.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_02.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_03.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_04.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_05.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_06.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_07.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_08.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_09.csv']

In [169]:
save_to_csv(output_dir, test_data_np, 'test_data', None, 10)
save_to_csv(output_dir, val_data_np, 'val_data', None,10)

['/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_00.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_01.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_02.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_03.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_04.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_05.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_06.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_07.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_08.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_09.csv']

In [88]:
dataset = tf.data.Dataset.list_files()

In [1]:
'''
    将csv数据每一行转换为特征值和标签值
    @param line line
    @param n_feature 特征纬度
    @return x 特征
    @return y 标签
'''
def parse_csv_line(line, n_feature=503):
    defs = [tf.constant(np.nan)] * n_feature
    parsed_field = tf.io.decode_csv(line, record_defaults=defs)
    x = tf.stack(parsed_field[0:-4])
    y_income = tf.stack(parsed_field[-4:-2])
    y_marital = tf.stack(parsed_field[-2:])
    return x, (y_income, y_marital)

In [110]:
'''
    将csv文件列表处理为TFDataSet
    @param filenames csv文件名列表
    @param n_reader interleave cycle_length
    @param batch_size batch_size
    @param n_parse_threads map的并行数
    @param shuffle_buffer_size shuffer_size
    @return dataset TFDataset
'''
def csv_reader_dataset(filenames, n_readers=5, batch_size=32, n_parse_threads=5, shuffle_buffer_size=10000):
    dataset = tf.data.Dataset.list_files(filenames)
    dataset = dataset.interleave(
        lambda filename : tf.data.TextLineDataset(filename),
        cycle_length=n_readers
    )
    dataset.shuffle(shuffle_buffer_size)
    dataset = dataset.map(parse_csv_line, num_parallel_calls=n_parse_threads)
    dataset = dataset.batch(batch_size)
    return dataset

In [171]:
train_dataset = csv_reader_dataset(['/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_00.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_01.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_02.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_03.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_04.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_05.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_06.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_07.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_08.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_09.csv'])

In [186]:
'''
    通过前缀获取文件列表
    @param source_dir 源文件目录
    @param prefix_name 前缀名
    @return result list
'''
def get_filename_by_prefix(source_dir,prefix_name):
    all_files = os.listdir(source_dir)
    results = []
    for filename in all_files:
        if filename.startswith(prefix_name):
            results.append(os.path.join(source_dir, filename))
    return results
source_dir='/Users/lizhen/Code/data_set/census-income/generate_csv/'
train_filenames = get_filename_by_prefix(source_dir, 'train')
valid_filenames = get_filename_by_prefix(source_dir, 'val')
test_filenames = get_filename_by_prefix(source_dir, 'test')

In [187]:
train_filenames

['/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_00.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_01.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_03.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_02.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_06.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_07.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_05.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_04.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_09.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/train_data_08.csv']

In [200]:
'''
    将数据serialize
    @param x feature
    @param y label
    @return serialize data
'''
def serialize(x, y):
    input_features = tf.train.FloatList(value=x)
    income_label = tf.train.Int64List(value=y[0:2])
    marital_label = tf.train.Int64List(value=y[2:])
    features = tf.train.Features(
        feature={
            'features':tf.train.Feature(float_list=input_features),
            'income_label':tf.train.Feature(int64_list=income_label),
            'marital_label':tf.train.Feature(int64_list=marital_label)
        }
    )
    example = tf.train.Example(features=features)
    return example.SerializeToString()

In [201]:
'''
    将csv数据转换为tfrecords数据
'''
def csv_dataset_to_tfrecords(base_filename, dataset, n_shards, step_per_shard, compression_type=None):
    options = tf.io.TFRecordOptions(compression_type=compression_type)
    filenames = []
    for shard_id in range(n_shards):
        filename_fullpath = '{}-{:05d}-of{:05d}'.format(base_filename, shard_id, n_shards)
        with tf.io.TFRecordWriter(filename_fullpath, options) as writer:
            for x_batch, y_batch in dataset.take(step_per_shard):
                for x_example, y_example in zip(x_batch, y_batch):
                    writer.write(serialize(x_example, y_example))
        filenames.append(filename_fullpath)
    return filenames

In [202]:
n_shards = 10
train_step_per_shard = train_data_np.shape[0] // 32//n_shards
test_step_per_shard = test_data_np.shape[0] // 32//n_shards
val_step_per_shard = val_data_np.shape[0] // 32//n_shards

In [203]:
train_step_per_shard

623

In [204]:
test_step_per_shard

155

In [205]:
val_step_per_shard

155

In [206]:
output_dir = '/Users/lizhen/Code/data_set/census-income/generate_tfrecords'
if not os.path.exists(output_dir):
    os.mkdir(output_dir)

train_basename = os.path.join(output_dir,'train')
test_basename = os.path.join(output_dir,'test')
val_basename = os.path.join(output_dir,'val')


In [207]:
valid_filenames

['/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_03.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_02.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_00.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_01.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_05.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_04.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_06.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_07.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_09.csv',
 '/Users/lizhen/Code/data_set/census-income/generate_csv/val_data_08.csv']

In [208]:
train_dataset = csv_reader_dataset(train_filenames)
test_dataset = csv_reader_dataset(test_filenames)
val_dataset = csv_reader_dataset(valid_filenames)

In [None]:
csv_dataset_to_tfrecords(train_basename, train_dataset, n_shards, train_step_per_shard,"GZIP")
csv_dataset_to_tfrecords(test_basename, test_dataset, n_shards, test_step_per_shard,"GZIP")
csv_dataset_to_tfrecords(val_basename, val_dataset, n_shards, val_step_per_shard,"GZIP")