In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
import os

print("--- Part 1: Splitting the Main Dataset (Memory-Efficient & Safe) ---")
# Streams FULL_DATASET_PATH in CHUNK_SIZE-row chunks and writes two CSVs:
#   - server_df.csv : ~SERVER_DATASET_SIZE of each chunk (initial global model data)
#   - clients_df.csv: the remaining rows (to be distributed to clients later)
# Each chunk is split independently, stratified on LABEL_COLUMN. A chunk that
# cannot be stratified is sent to the clients file in full.

# --- Configuration with your specific paths ---
BASE_DIR = r"D:\FedShield,Personal\wataiData\csv\CICIoT2023"
FULL_DATASET_PATH = os.path.join(BASE_DIR, 'full_dataset.csv')
OUTPUT_DIR = os.path.join(BASE_DIR, 'federated_data')

# --- Other Configuration ---
SERVER_DATASET_SIZE = 0.10  # 10% for the initial global model
CHUNK_SIZE = 100000         # Process 100,000 rows at a time
LABEL_COLUMN = 'label'      # Column used to stratify the server/client split
RANDOM_STATE = 42           # Fixed seed so reruns produce the same split

# A float test_size must lie strictly between 0 and 1. Fail fast on a bad
# fraction rather than having every chunk fall through to the fallback path.
if not 0 < SERVER_DATASET_SIZE < 1:
    raise ValueError(f"SERVER_DATASET_SIZE must be in (0, 1), got {SERVER_DATASET_SIZE}")

# --- Create Output Directory ---
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)
    print(f"Created directory: {OUTPUT_DIR}")

# --- Define Output File Paths ---
server_output_path = os.path.join(OUTPUT_DIR, 'server_df.csv')
clients_output_path = os.path.join(OUTPUT_DIR, 'clients_df.csv')

# --- Remove old files if they exist ---
# The first write below opens with mode='w' anyway. Deleting up front also
# guarantees no stale output survives a run in which every chunk is empty.
for stale_path in (server_output_path, clients_output_path):
    if os.path.exists(stale_path):
        os.remove(stale_path)

# --- Process the large CSV in chunks ---
print(f"Loading and processing {FULL_DATASET_PATH} in chunks...")
header_written = False  # True once both output files exist with a header row
reader = pd.read_csv(FULL_DATASET_PATH, chunksize=CHUNK_SIZE)

for chunk_num, chunk in enumerate(reader, start=1):
    print(f"  -> Processing chunk {chunk_num}...")
    chunk = chunk.dropna()
    if chunk.empty:
        continue

    server_chunk = None  # stays None when a stratified split is impossible
    label_counts = chunk[LABEL_COLUMN].value_counts()
    if (label_counts < 2).any():
        # A label with a single row cannot be represented on both sides of a
        # stratified split, and train_test_split would raise.
        print(f"     ! Chunk {chunk_num} has a rare label with only 1 sample. Assigning full chunk to clients_df to avoid error.")
    else:
        try:
            server_chunk, clients_chunk = train_test_split(
                chunk,
                test_size=(1 - SERVER_DATASET_SIZE),
                random_state=RANDOM_STATE,
                stratify=chunk[LABEL_COLUMN],
            )
        except ValueError as err:
            # Stratification can still fail when every label has >= 2 rows,
            # e.g. a small final chunk whose server share has fewer rows than
            # there are classes.
            print(f"     ! Chunk {chunk_num} cannot be stratified ({err}). Assigning full chunk to clients_df.")

    if server_chunk is None:
        # Fallback: the whole chunk goes to clients. The zero-row slice keeps
        # the column layout, so the server file still gets a header if this
        # is the first chunk written.
        server_chunk = chunk.iloc[:0]
        clients_chunk = chunk

    # The first chunk to reach this point creates both files with a header.
    # Later chunks append data rows only (an empty frame appends nothing).
    mode = 'a' if header_written else 'w'
    server_chunk.to_csv(server_output_path, index=False, mode=mode, header=not header_written)
    clients_chunk.to_csv(clients_output_path, index=False, mode=mode, header=not header_written)
    header_written = True

if not header_written:
    print("\n⚠️ No rows remained after dropping missing values; no output files were written.")

print(f"\n✅ Part 1 Complete!")
print(f"Data has been split and saved chunk by chunk.")
print(f"Initial server data saved to: {server_output_path}")
print(f"Combined client data saved to: {clients_output_path}")