# Edge SleepFM-Lite: Colab Setup

Thiết lập môi trường và biến dùng chung cho pipeline.

In [None]:
!pip install -q numpy pandas pyarrow h5py

In [None]:
# Mount Google Drive at /content/drive so the project folder under
# /content/drive/MyDrive is reachable from this runtime.
from google.colab import drive

drive.mount(mountpoint='/content/drive')


In [None]:
import os

PROJECT_DIR = '/content/drive/MyDrive/edge-sleepfm-lite'
RAW_DIR = os.path.join(PROJECT_DIR, 'data', 'raw')
HDF5_DIR = os.path.join(PROJECT_DIR, 'data', 'canonical_hdf5')
PARQUET_DIR = os.path.join(PROJECT_DIR, 'data', 'training_store')
SPLIT_PATH = os.path.join(PROJECT_DIR, 'data', 'splits')
TEACHER_TARGET_DIR = os.path.join(PROJECT_DIR, 'data', 'teacher_targets')
CHECKPOINT_DIR = os.path.join(PROJECT_DIR, 'checkpoints')

print('PROJECT_DIR:', PROJECT_DIR)
print('RAW_DIR:', RAW_DIR)
print('HDF5_DIR:', HDF5_DIR)
print('PARQUET_DIR:', PARQUET_DIR)
print('SPLIT_PATH:', SPLIT_PATH)
print('TEACHER_TARGET_DIR:', TEACHER_TARGET_DIR)
print('CHECKPOINT_DIR:', CHECKPOINT_DIR)


In [None]:
from itertools import combinations

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sha2, substring, countDistinct, max as spark_max, when

spark = SparkSession.builder.appName("sleepfm-index-split").getOrCreate()

# Subject-level split: all rows of a subject land in the same split, chosen by the
# first two hex characters of sha256(subject_id). Spark's sha2 emits lowercase hex,
# so string comparison orders the 256 prefixes '00' < ... < '99' < 'a0' < ... < 'ff'.
# Assuming prefixes are roughly uniform, these bounds give about:
#   train '00'..'aa' 171/256 (~66.8%)
#   val   'ab'..'cc'  34/256 (~13.3%)
#   test  'cd'..'ff'  51/256 (~19.9%)
TRAIN_MAX_PREFIX = 'aa'
VAL_MAX_PREFIX = 'cc'
SPLIT_NAMES = ('train', 'val', 'test')

metadata_path = os.path.join(RAW_DIR, 'metadata.csv')
meta_df = spark.read.csv(metadata_path, header=True)

meta_df = meta_df.withColumn('subject_hash', sha2(col('subject_id'), 256))

hash_prefix = substring(col('subject_hash'), 1, 2)
split_expr = (
    when(hash_prefix <= TRAIN_MAX_PREFIX, 'train')
    .when(hash_prefix <= VAL_MAX_PREFIX, 'val')
    .otherwise('test')
)
# Cached because the frame feeds several actions below (checks, three writes,
# summary); without it Spark re-reads and re-hashes the CSV for each one.
meta_df = meta_df.withColumn('split', split_expr).cache()

if meta_df.count() == 0:
    raise ValueError(f'No metadata rows found in {metadata_path}; nothing to split.')

# sha2(NULL) is NULL, and a NULL prefix fails both `when` tests, so rows without a
# subject_id would silently fall through to 'test' with no subject to check for
# leakage. Refuse them instead of guessing.
missing_subject_rows = meta_df.filter(col('subject_id').isNull()).count()
if missing_subject_rows:
    raise ValueError(
        f'{missing_subject_rows} row(s) in {metadata_path} have no subject_id; '
        'fix the metadata before splitting.'
    )

train_df = meta_df.filter(col('split') == 'train')
val_df = meta_df.filter(col('split') == 'val')
test_df = meta_df.filter(col('split') == 'test')

train_df.write.mode('overwrite').parquet(os.path.join(SPLIT_PATH, 'train'))
val_df.write.mode('overwrite').parquet(os.path.join(SPLIT_PATH, 'val'))
test_df.write.mode('overwrite').parquet(os.path.join(SPLIT_PATH, 'test'))

# Leakage check on what was actually written: no subject_id may appear in two
# split directories. (Grouping meta_df by subject_id would pass by construction,
# because 'split' is a pure function of subject_id.)
written_subjects = {
    name: spark.read.parquet(os.path.join(SPLIT_PATH, name)).select('subject_id').distinct()
    for name in SPLIT_NAMES
}
for first, second in combinations(SPLIT_NAMES, 2):
    shared = written_subjects[first].intersect(written_subjects[second]).count()
    if shared:
        raise RuntimeError(
            f'{shared} subject_id(s) appear in both the {first} and {second} splits.'
        )
print('Verified: each subject_id is assigned to a single split.')

# Per-split subject counts. An empty split usually means there are too few
# subjects for these hash bounds.
subjects_per_split = {
    row['split']: row['subjects']
    for row in meta_df.groupBy('split')
    .agg(countDistinct('subject_id').alias('subjects'))
    .collect()
}
for name in SPLIT_NAMES:
    n_subjects = subjects_per_split.get(name, 0)
    print(f'{name}: {n_subjects} subject(s)')
    if n_subjects == 0:
        print(f'WARNING: the {name} split is empty.')
