In [8]:
%%capture --no-stderr
%pip install pillow 

In [5]:
import os 
import zipfile

root_folder = './data'

def unzip_all_in_folder(folder_path):
    for root, dirs, files in os.walk(folder_path):
        for file in files:
            if file.endswith('.zip'):
                file_path = os.path.join(root, file)
                # Create a new folder with the same name as the zip file (without the .zip extension)
                extract_folder = os.path.join(root, os.path.splitext(file)[0])
                os.makedirs(extract_folder, exist_ok=True)
                
                with zipfile.ZipFile(file_path, 'r') as zip_ref:
                    zip_ref.extractall(extract_folder)
                print(f'Unzipped: {file_path} to {extract_folder}')
                os.remove(file_path)  # Delete the zip file after extracting its contents

unzip_all_in_folder(os.path.join(root_folder, 'real'))
unzip_all_in_folder(os.path.join(root_folder, 'fake'))

print("All zip files have been extracted!")

All zip files have been extracted!


In [13]:
import os
import random
import shutil
from PIL import Image
import json

# Directories
real_dir = 'data/real'
fake_dir = 'data/fake'
output_dir = 'sets'
train_dir = os.path.join(output_dir, 'train')
test_dir = os.path.join(output_dir, 'test')

# Create the output directories if they don't exist
os.makedirs(train_dir, exist_ok=True)
os.makedirs(test_dir, exist_ok=True)

# Remove resizing and directly save images
def save_image(image_path, output_path, quality=90):
    try:
        with Image.open(image_path) as img:
            img.save(output_path, 'JPEG', quality=quality)
    except Exception as e:
        print(f"Error processing {image_path}: {e}")

# Function to split images into train and test
def process_images(image_paths, num_test, num_train, dataset_type, category, output_test, output_train):
    random.shuffle(image_paths)  # Shuffle images to ensure random sampling
    
    test_images = image_paths[:num_test]  # Select num_test images for testing
    train_images = image_paths[num_test:num_test + num_train]  # Select num_train images for training
    
    # Copy images to test and train directories
    for img in test_images:
        output_path = os.path.join(output_test, f'{dataset_type}_{category}_{os.path.basename(img)}')
        save_image(img, output_path, quality=random.randint(50, 100))
    
    for img in train_images:
        output_path = os.path.join(output_train, f'{dataset_type}_{category}_{os.path.basename(img)}')
        save_image(img, output_path, quality=random.randint(50, 100))

# Recursively get all images in a directory
def get_images_in_directory(dir_path):
    image_paths = []
    for root, _, files in os.walk(dir_path):
        for file in files:
            if file.endswith(('png', 'jpg', 'jpeg')):
                image_paths.append(os.path.join(root, file))
    return image_paths

# Real and Fake categories setup
real_categories = ['coco', 'ffhq', 'imagenet', 'lsun']
fake_categories = ['generative_inpainting', 'glide', 'stylegan2', 'stylegan3', 'taming_transformer']

# Total number of images needed for real and fake categories
num_test_real = 625  # Test images per real category
num_train_real = 5625  # Train images per real category (9x the test set)
num_test_fake = 500  # Test images per fake category
num_train_fake = 4500  # Train images per fake category (9x the test set)

# Process real images (balanced between 4 categories)
for category in real_categories:
    real_images = get_images_in_directory(os.path.join(real_dir, category))
    if len(real_images) < num_test_real + num_train_real:
        print(f"Not enough images in category {category}. Needed: {num_test_real + num_train_real}, found: {len(real_images)}")
        continue  # Skip categories that don't have enough images
    process_images(real_images, num_test_real, num_train_real, 'real', category, test_dir, train_dir)

# Process fake images (balanced between 5 categories)
for category in fake_categories:
    fake_images = get_images_in_directory(os.path.join(fake_dir, category))
    if len(fake_images) < num_test_fake + num_train_fake:
        print(f"Not enough images in category {category}. Needed: {num_test_fake + num_train_fake}, found: {len(fake_images)}")
        continue  # Skip categories that don't have enough images
    process_images(fake_images, num_test_fake, num_train_fake, 'fake', category, test_dir, train_dir)

# Generate dataset structure for Hugging Face or JSON
def create_dataset_metadata(image_dir):
    dataset = []
    for root, _, files in os.walk(image_dir):
        for file in files:
            if file.endswith(('png', 'jpg', 'jpeg')):
                image_path = os.path.join(root, file)
                # Check if the file path contains 'real' or 'fake' in the name and assign the correct label
                if 'real' in file:
                    label = 'real'
                elif 'fake' in file:
                    label = 'fake'
                else:
                    label = 'unknown'
                data = {
                    'image_path': image_path,
                    'label': label
                }
                dataset.append(data)
    return dataset

# Create metadata for Hugging Face dataset structure
train_data = create_dataset_metadata(train_dir)
test_data = create_dataset_metadata(test_dir)

with open(os.path.join(output_dir, 'train_dataset.json'), 'w') as f:
    json.dump(train_data, f, indent=4)

with open(os.path.join(output_dir, 'test_dataset.json'), 'w') as f:
    json.dump(test_data, f, indent=1)


print("Balanced train and test dataset creation complete.")


Balanced train and test dataset creation complete.


In [16]:
import os
from collections import defaultdict

# Directories for train and test sets
train_dir = 'sets/train'
test_dir = 'sets/test'

# Function to summarize the dataset
def summarize_images(image_dir):
    summary = defaultdict(lambda: defaultdict(int))  # Nested defaultdict for category and type (real/fake)
    
    # Walk through each image in the directory and categorize it
    for root, _, files in os.walk(image_dir):
        for file in files:
            if file.endswith(('png', 'jpg', 'jpeg')):
                file_name = file.lower()
                if 'real' in file_name:
                    dataset_type = 'real'
                elif 'fake' in file_name:
                    dataset_type = 'fake'
                else:
                    dataset_type = 'unknown'  # If file is not classified as real or fake (just in case)
                
                # Extract category from the file name
                category = file_name.split('_')[1]  # Assumes file names are structured as: type_category_filename

                # Update count for the specific dataset type and category
                summary[dataset_type][category] += 1
    
    return summary

# Function to print the summary
def print_summary(summary, dataset_name):
    print(f"Summary for {dataset_name} dataset:")
    for dataset_type, categories in summary.items():
        total_images = sum(categories.values())
        print(f"  {dataset_type.capitalize()} images: {total_images}")
        for category, count in categories.items():
            print(f"    {category}: {count}")
    print()

# Summarize train and test datasets
train_summary = summarize_images("sets/train")
test_summary = summarize_images("sets/test")

# Print out the summaries
print_summary(train_summary, "Train")
print_summary(test_summary, "Test")


Summary for Train dataset:
  Real images: 22445
    lsun: 5570
    ffhq: 5625
    imagenet: 5625
    coco: 5625
  Fake images: 20614
    glide: 3910
    stylegan3: 3488
    stylegan2: 4408
    taming: 4392
    generative: 4416

Summary for Test dataset:
  Real images: 2498
    lsun: 623
    imagenet: 625
    ffhq: 625
    coco: 625
  Fake images: 2469
    glide: 491
    stylegan3: 485
    taming: 499
    generative: 498
    stylegan2: 496

