In [5]:
import os
import shutil
import logging
import time
from concurrent.futures import ThreadPoolExecutor

def setup_logging():
    """Configure the root logger to write INFO-and-above records to ``batch_processing.log``.

    This is a thin wrapper around ``logging.basicConfig``. Per the stdlib docs, that call
    is a no-op if the root logger already has handlers, and it appends to the file
    (default ``filemode='a'``).
    """
    log_settings = dict(
        filename="batch_processing.log",
        level=logging.INFO,
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
    logging.basicConfig(**log_settings)

def copy_file(src, dst):
    """Copy *src* to *dst* (contents plus metadata via ``shutil.copy2``) on a best-effort basis.

    Any exception raised by the copy is logged at ERROR level and swallowed.
    The function always returns ``None``, so callers cannot observe failures
    except through the log.
    """
    failure = None
    try:
        shutil.copy2(src, dst)
    except Exception as err:
        # Deliberately broad: this runs inside executor jobs whose futures are never inspected.
        failure = f"Error copying {src} to {dst}: {err}"
    if failure is not None:
        logging.error(failure)

def _files_by_stem(folder):
    """Return ``{stem: filename}`` for the regular files directly inside *folder*.

    Sub-directories are skipped so they are never handed to ``copy_file``.
    If two files share a stem (e.g. ``a.jpg`` and ``a.png``), only one is kept;
    which one depends on ``os.listdir`` order.
    """
    return {
        os.path.splitext(name)[0]: name
        for name in os.listdir(folder)
        if os.path.isfile(os.path.join(folder, name))
    }


def create_federated_batches(source_base_folder, destination_base_folder, num_batches, max_workers=8):
    """Split each dataset split into ``num_batches`` federated-learning batch folders.

    Source layout read by this function:
        <source>/images/<split>/             for split in train, val, test
        <source>/labels/det_train, det_val   (the test split has no labels)

    Output layout:
        <destination>/batch_<k>/<split>/images/
        <destination>/batch_<k>/<split>/labels/   (only when the split has label files)

    Images and labels are paired by file stem. For a split with label files, only
    images that have a matching label are used. If a split has no label files, every
    image is used and a warning is logged. The sorted stems are dealt out in contiguous
    runs whose sizes differ by at most one. The first ``total % num_batches`` batches
    get the extra file; if there are fewer files than batches, the trailing batches
    are empty.

    Copies run on a thread pool; per-file copy failures are logged by ``copy_file``.
    Any other error is logged with its traceback and stops processing. This function
    never raises and always returns ``None``.

    Args:
        source_base_folder: Root folder containing ``images/`` and ``labels/``.
        destination_base_folder: Root under which ``batch_<k>`` folders are created.
        num_batches: Number of batches per split; must be >= 1.
        max_workers: Thread-pool size for the copy jobs (default 8).
    """
    try:
        if num_batches < 1:
            logging.error(f"num_batches must be >= 1, got {num_batches}; nothing to do.")
            return

        # Label sub-directory for each split; None means the split ships without labels.
        dataset_mapping = {"train": "det_train", "val": "det_val", "test": None}

        for dataset, label_subdir in dataset_mapping.items():
            image_source = os.path.join(source_base_folder, "images", dataset)
            label_source = os.path.join(source_base_folder, "labels", label_subdir) if label_subdir else None

            if not os.path.exists(image_source):
                logging.warning(f"Skipping {dataset} - Missing image folder: {image_source}")
                continue

            image_files = _files_by_stem(image_source)
            label_files = _files_by_stem(label_source) if label_source and os.path.exists(label_source) else {}

            if label_subdir and not label_files:
                # Keep the images-only fallback, but do not let it happen silently for a labelled split.
                logging.warning(f"{dataset}: no label files found in {label_source}; batching images without labels.")

            # With labels, only stems present on both sides are usable; without labels, every image is used.
            if label_files:
                matched_keys = sorted(image_files.keys() & label_files.keys())
            else:
                matched_keys = sorted(image_files)

            if not matched_keys:
                logging.error(f"No matching images and labels found in {dataset}, skipping.")
                continue

            total_files = len(matched_keys)
            # Sizes sum exactly to total_files: every batch gets batch_size files and
            # the first `leftover` batches get one more.
            batch_size, leftover = divmod(total_files, num_batches)

            logging.info(f"{dataset}: {total_files} files, Batch size: {batch_size}, Leftover: {leftover}")

            start_time = time.time()
            index = 0

            # Leaving the `with` block waits for every submitted copy, so the elapsed
            # time logged below includes the copying itself.
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                for i in range(num_batches):
                    batch_folder = os.path.join(destination_base_folder, f"batch_{i+1}", dataset)
                    batch_image_folder = os.path.join(batch_folder, "images")
                    batch_label_folder = os.path.join(batch_folder, "labels") if label_files else None
                    os.makedirs(batch_image_folder, exist_ok=True)
                    if batch_label_folder:
                        os.makedirs(batch_label_folder, exist_ok=True)

                    current_batch_size = batch_size + (1 if i < leftover else 0)
                    batch_keys = matched_keys[index:index + current_batch_size]
                    index += current_batch_size

                    for key in batch_keys:
                        img_name = image_files[key]
                        executor.submit(
                            copy_file,
                            os.path.join(image_source, img_name),
                            os.path.join(batch_image_folder, img_name),
                        )

                        if label_files:
                            lbl_name = label_files[key]
                            executor.submit(
                                copy_file,
                                os.path.join(label_source, lbl_name),
                                os.path.join(batch_label_folder, lbl_name),
                            )

                    logging.info(f"Batch {i+1} - {dataset}: {len(batch_keys)} files.")

            end_time = time.time()
            logging.info(f"{dataset} batch processing completed in {end_time - start_time:.2f} seconds.")
    except Exception as e:
        # logging.exception logs at ERROR level like before, but keeps the traceback.
        logging.exception(f"Error creating batches: {e}")

if __name__ == "__main__":
    setup_logging()
    # Machine-specific paths: edit for your environment. Inside a raw string (r"...")
    # a single backslash is already literal, so the separators must not be doubled.
    source_base_folder = r"C:\Users\sathish\Downloads\FL_ModelForAV\data\bdd100k"
    destination_base_folder = r"C:\Users\sathish\Downloads\FL_ModelForAV\data\bdd100_batch"
    num_batches = 20

    logging.info("Federated batch processing started.")
    create_federated_batches(source_base_folder, destination_base_folder, num_batches)
    logging.info("Federated batch processing completed.")
