In [1]:
import os, shutil, re, random
from collections import defaultdict
import numpy as np
from glob import glob

# Seed the global `random` generator once so the random.shuffle() calls made
# by the functions below start from a fixed state.
random.seed(42)

def extract_id(filename):
    """Return the run of digits that starts *filename*, as a string.

    For a name such as ``01__something`` this yields ``"01"`` (leading zeros
    are kept). Names that do not begin with a digit yield ``"unknown"``.
    """
    leading_digits = re.match(r"(\d+)", filename)
    if leading_digits is None:
        return "unknown"
    return leading_digits.group(1)

def collect_by_id(folder):
    """Group the ``*.jpg`` files directly inside *folder* by person ID.

    The ID of each file is whatever ``extract_id`` returns for its base name.

    Args:
        folder: Directory to scan (not recursive; only ``*.jpg`` is matched).

    Returns:
        A ``defaultdict(list)`` mapping person ID -> list of file paths,
        with the paths in sorted order.
    """
    groups = defaultdict(list)
    # glob() yields matches in arbitrary, filesystem-dependent order. Sorting
    # keeps the grouping stable so that seeded shuffles applied to it later
    # give the same result on every machine.
    for path in sorted(glob(os.path.join(folder, "*.jpg"))):
        fname = os.path.basename(path)
        pid = extract_id(fname)
        groups[pid].append(path)
    return groups

def select_samples(groups, max_total=1000, frames_per_id=8):
    """Randomly pick up to *frames_per_id* files from each ID, capped overall.

    IDs are visited in a random order; for each one a shuffled copy of its
    file list is made and its first *frames_per_id* entries are taken.
    Visiting stops once at least *max_total* files have been gathered, and
    the result is truncated to *max_total* entries. The lists stored in
    *groups* are not modified.
    """
    visit_order = [*groups]
    random.shuffle(visit_order)

    picked = []
    for person in visit_order:
        if len(picked) >= max_total:
            break
        # Shuffle a copy so the caller's list keeps its order.
        pool = list(groups[person])
        random.shuffle(pool)
        picked += pool[:frames_per_id]
    return picked[:max_total]


def split_by_id(real_dir, fake_dir, out_dir, max_samples=1000, frames_per_id=3, test_ratio=0.2):
    """Copy an ID-disjoint train/test split of real and fake frames to *out_dir*.

    Only person IDs present in both *real_dir* and *fake_dir* are used. The
    IDs are shuffled and the first ``int(len(ids) * test_ratio)`` become the
    test IDs; the rest are train IDs, so no ID appears in both splits.
    Frames are then sampled per split and per class with ``select_samples``
    and copied into::

        out_dir/train/real            out_dir/train/fake_generated
        out_dir/test/real             out_dir/test/fake_generated

    Args:
        real_dir: Folder of real frames (scanned by ``collect_by_id``).
        fake_dir: Folder of fake frames (scanned by ``collect_by_id``).
        out_dir: Root folder to create the split under.
        max_samples: Cap on the number of files in each of the four subsets.
        frames_per_id: Maximum files taken per ID in each subset.
        test_ratio: Fraction of the shared IDs assigned to the test split
            (rounded down, so it can be zero for small ID counts).
    """
    real_groups = collect_by_id(real_dir)
    fake_groups = collect_by_id(fake_dir)

    # Use only overlapping IDs to simulate deepfake scenario.
    # The IDs are sorted because iteration order of a set of strings changes
    # between interpreter runs (hash randomization); shuffling an unsorted set
    # would make the split differ from run to run even with a fixed seed.
    ids = sorted(set(real_groups).intersection(fake_groups))
    real_groups = {k: real_groups[k] for k in ids}
    fake_groups = {k: fake_groups[k] for k in ids}

    random.shuffle(ids)
    n_test = int(len(ids) * test_ratio)
    # Sets give O(1) membership tests in gather().
    train_ids, test_ids = set(ids[n_test:]), set(ids[:n_test])

    def copy_files(files, dest):
        """Copy every path in *files* into *dest*, creating it if needed."""
        os.makedirs(dest, exist_ok=True)
        for f in files:
            shutil.copy2(f, os.path.join(dest, os.path.basename(f)))

    def gather(groups, id_set):
        """Sample frames from the entries of *groups* whose ID is in *id_set*."""
        return select_samples({k: v for k, v in groups.items() if k in id_set},
                              max_total=max_samples, frames_per_id=frames_per_id)

    # Select samples
    train_real = gather(real_groups, train_ids)
    train_fake = gather(fake_groups, train_ids)
    test_real = gather(real_groups, test_ids)
    test_fake = gather(fake_groups, test_ids)

    print(f"✅ Train: {len(train_real)} real + {len(train_fake)} fake")
    print(f"✅ Test:  {len(test_real)} real + {len(test_fake)} fake")

    # Save
    copy_files(train_real, os.path.join(out_dir, "train", "real"))
    copy_files(train_fake, os.path.join(out_dir, "train", "fake_generated"))
    copy_files(test_real, os.path.join(out_dir, "test", "real"))
    copy_files(test_fake, os.path.join(out_dir, "test", "fake_generated"))

    print(f"✅ Dataset saved to {out_dir}")

# Build the DFD split: at most 1000 files per subset, at most 8 frames per
# person ID, and 20% of the shared IDs held out for testing.
split_by_id(
    "dataset/dfd_real",
    "dataset/dfd_fake",
    "balanced_dataset_dfd",
    max_samples=1000,
    frames_per_id=8,
    test_ratio=0.2,
)


✅ Train: 184 real + 184 fake
✅ Test:  40 real + 40 fake
✅ Dataset saved to balanced_dataset_dfd
