# In [None]:  (Jupyter notebook cell marker left over from export)
import os
import shutil
import glob
import pandas as pd
import tensorflow as tf
from multiprocessing import Pool

# Paths
# DATA_DIR may be overridden through the environment; otherwise the cluster
# default below is used. All other locations are derived from it.
data_dir = os.environ.get('DATA_DIR',
                          '/gpfs/home/zh283/StockPredictionDNN/Data')
parquet_dir = os.path.join(data_dir, 'parquet')
factor_xlsx = os.path.join(data_dir, 'factors_list.xlsx')
tfrecord_dir = os.path.join(data_dir, 'tfrecords')

# Clean out old TFRecords.
# Only the entry-point process may delete the output directory. Under the
# "spawn" / "forkserver" multiprocessing start methods every pool worker
# re-imports this module (with __name__ == '__mp_main__'); an unguarded rmtree
# would then wipe shards that other workers are currently writing.
if __name__ == '__main__':
    if os.path.exists(tfrecord_dir):
        shutil.rmtree(tfrecord_dir)
# Creating the directory is idempotent, so it is safe in every process.
os.makedirs(tfrecord_dir, exist_ok=True)

# Load characteristic list to define feature columns: every non-missing
# entry of the 'abr_jkp' column becomes one feature name, in sheet order.
chars = pd.read_excel(factor_xlsx)
FEATURE_COLS = chars['abr_jkp'].dropna().tolist()

# Other columns written alongside the feature vector in each record.
META_COLS = [
    'permno',
    'eom',
    'me',
    'size_grp',
    'crsp_exchcd',
    'ret',
    'ret_exc',
]
WEIGHT_COLS = [
    'w_ew',
    'w_vw',
]
LABEL_COLS = [
    'ret_exc_lead1m',
    'ret_pct',
    'ret_z',
    'ret_invn',
]


# Helper for features
def _bytes_feature(v):
    """Wrap the single value ``v`` in a bytes-list ``tf.train.Feature``."""
    payload = tf.train.BytesList(value=[v])
    return tf.train.Feature(bytes_list=payload)


def _float_feature(v):
    """Wrap ``v`` in a float-list ``tf.train.Feature``.

    ``v`` is passed through unchanged as the list's values, so callers supply
    a sequence rather than a scalar.
    """
    payload = tf.train.FloatList(value=v)
    return tf.train.Feature(float_list=payload)


def _int64_feature(v):
    """Wrap ``v`` in an int64-list ``tf.train.Feature``.

    ``v`` is passed through unchanged as the list's values, so callers supply
    a sequence rather than a scalar.
    """
    payload = tf.train.Int64List(value=v)
    return tf.train.Feature(int64_list=payload)


def serialize_example(rec):
    """Serialize one record into a ``tf.train.Example`` byte string.

    Args:
        rec: Mapping from column name to value. It must contain every name in
            ``WEIGHT_COLS``, ``META_COLS`` and ``LABEL_COLS``, plus either a
            ready-made ``'feat'`` vector or every name in ``FEATURE_COLS``.

    Returns:
        The serialized ``tf.train.Example`` as ``bytes``.

    Raises:
        KeyError: If a required column is missing from ``rec``.
        TypeError: If a value is not an int or string and ``float()`` rejects
            its type.
        ValueError: If ``float()`` cannot parse a value.
    """
    feat = {}

    # feature vector: prefer a pre-assembled 'feat' entry, otherwise gather
    # the individual characteristic columns in FEATURE_COLS order.
    if 'feat' in rec:
        vector = rec['feat']
    else:
        vector = [rec[c] for c in FEATURE_COLS]
    feat['feat'] = _float_feature(vector)

    # weights, meta, label: one single-element feature per column, with the
    # feature kind chosen from the Python type of the value.
    # NOTE(review): the kind is decided per value, so a column's feature type
    # follows whatever type the caller hands in for that record — confirm the
    # reader's parsing spec matches.
    for c in WEIGHT_COLS + META_COLS + LABEL_COLS:
        val = rec[c]
        if isinstance(val, int):
            feat[c] = _int64_feature([val])
        elif isinstance(val, bytes):
            feat[c] = _bytes_feature(val)
        elif isinstance(val, str):
            # Text columns cannot go through float(); store them as UTF-8
            # bytes instead of raising.
            feat[c] = _bytes_feature(val.encode('utf-8'))
        else:
            feat[c] = _float_feature([float(val)])

    example = tf.train.Example(features=tf.train.Features(feature=feat))
    return example.SerializeToString()


def process_parquet_shard(args):
    """Convert one parquet file into one TFRecord file.

    Args:
        args: A ``(parquet_file, variant, tfrecord_dir)`` tuple, packed into a
            single argument so the function can be handed to ``Pool.map``.
            The output is written to
            ``<tfrecord_dir>/<variant>-<parquet basename>.tfrecord``.
    """
    parquet_file, variant, tfrecord_dir = args
    # derive a nice shard name from the filename
    # NOTE(review): only the basename is used, so two parquet files with the
    # same name in different year=* directories would map to the same output
    # path — confirm basenames are unique within a variant.
    fname = os.path.splitext(os.path.basename(parquet_file))[0]
    out_path = os.path.join(tfrecord_dir, f"{variant}-{fname}.tfrecord")

    # Read first so that a failing read does not leave an empty shard behind.
    df = pd.read_parquet(parquet_file)

    # The context manager closes (and flushes) the writer even when
    # serialization of a row raises.
    with tf.io.TFRecordWriter(out_path) as writer:
        for _, row in df.iterrows():
            writer.write(serialize_example(row.to_dict()))


if __name__ == '__main__':
    variants = ['raw', 'pct', 'z', 'invn']

    # One task per parquet file found under <parquet_dir>/<variant>/year=*/,
    # grouped by variant in the order listed above.
    tasks = [
        (parquet_file, variant, tfrecord_dir)
        for variant in variants
        for parquet_file in glob.glob(
            os.path.join(parquet_dir, variant, 'year=*', '*.parquet'),
            recursive=True,
        )
    ]

    # Pool() without a size lets multiprocessing choose the worker count
    # (one per CPU by default); map blocks until every shard is written.
    with Pool() as pool:
        pool.map(process_parquet_shard, tasks)