In [1]:
import os
import cv2
import numpy as np
from tqdm import tqdm
import random

In [2]:
# Set the desired local path for your data preprocessing
base_dir = "/media/hghosh/HGHOSH DISK/dataset"

# Create the base directory if it doesn't exist
os.makedirs(base_dir, exist_ok=True)

print(f"Directory created at: {base_dir}")


Directory created at: /media/hghosh/HGHOSH DISK/dataset


In [3]:
# Configure max images
MAX_IMAGES = 10000

# Function to download from URL
def download_from_url(url, category):
    category_dir = os.path.join(base_dir, category)
    os.makedirs(category_dir, exist_ok=True)

    # Download the file
    filename = os.path.join(category_dir, url.split('/')[-1])
    print(f"Downloading {category} dataset...")

    # Skip if already downloaded
    if os.path.exists(filename):
        print(f"File already exists: {filename}")
    else:
        try:
            response = requests.get(url, stream=True)
            total_size = int(response.headers.get('content-length', 0))

            with open(filename, 'wb') as f, tqdm(
                total=total_size, unit='B', unit_scale=True, unit_divisor=1024
            ) as bar:
                for data in response.iter_content(chunk_size=1024):
                    size = f.write(data)
                    bar.update(size)
        except Exception as e:
            print(f"Error downloading {url}: {e}")
            return False

    # Extract based on file extension
    print(f"Extracting {category} dataset...")
    try:
        if filename.endswith('.tgz') or filename.endswith('.tar.gz'):
            with tarfile.open(filename) as tar:
                tar.extractall(path=category_dir)
        elif filename.endswith('.zip'):
            with zipfile.ZipFile(filename) as zip_ref:
                zip_ref.extractall(category_dir)

        # Remove the compressed file to save space
        # Uncomment if you want to keep the original archives
        # os.remove(filename)

        print(f"{category} dataset extracted.")
        return True
    except Exception as e:
        print(f"Error extracting {filename}: {e}")
        return False

# Function to download from Kaggle
def download_from_kaggle(dataset, category):
    category_dir = os.path.join(base_dir, category)
    os.makedirs(category_dir, exist_ok=True)

    # Check if kaggle CLI is installed
    try:
        subprocess.run(["kaggle", "--version"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    except (subprocess.SubprocessError, FileNotFoundError):
        print("Kaggle CLI not found. Please install it with: pip install kaggle")
        print("And set up your API credentials: https://github.com/Kaggle/kaggle-api#api-credentials")
        return False

    # Check if API credentials are set up
    kaggle_dir = os.path.join(os.path.expanduser("~"), ".kaggle")
    kaggle_json = os.path.join(kaggle_dir, "kaggle.json")

    if not os.path.exists(kaggle_json):
        print("Kaggle API credentials not found.")
        print("Please create a token at https://www.kaggle.com/settings/account")
        print(f"Then create {kaggle_json} with your API key and username.")
        return False

    # Make sure permissions are correct
    try:
        os.chmod(kaggle_json, 0o600)
    except:
        print(f"Warning: Could not set permissions on {kaggle_json}")

    # Download the dataset
    try:
        print(f"Downloading {category} dataset from Kaggle...")
        subprocess.run(
            ["kaggle", "datasets", "download", "-d", dataset, "-p", category_dir],
            check=True
        )

        # Find the downloaded zip file
        zip_files = [f for f in os.listdir(category_dir) if f.endswith('.zip')]
        if not zip_files:
            print("No zip file found after download.")
            return False

        zip_file = os.path.join(category_dir, zip_files[0])

        # Extract the zip file
        print(f"Extracting {zip_file}...")
        with zipfile.ZipFile(zip_file, 'r') as zip_ref:
            zip_ref.extractall(category_dir)

        # Remove the zip file
        os.remove(zip_file)
        print(f"{category} dataset extracted.")
        return True
    except subprocess.CalledProcessError as e:
        print(f"Error downloading from Kaggle: {e}")
        return False
    except Exception as e:
        print(f"Error processing Kaggle dataset: {e}")
        return False


In [4]:
def preprocess_dataset(category):
    print(f"Preprocessing {category} dataset...")

    # Define directories
    category_dir = os.path.join(base_dir, category)
    processed_dir = os.path.join(base_dir, f"{category}_processed")
    pairs_dir = os.path.join(base_dir, f"{category}_pairs")
    # Check if the dataset has already been processed
    if os.path.exists(processed_dir) and os.path.exists(pairs_dir):
        print(f"{category} dataset already processed. Skipping...")
        return

    # Create output directories
    os.makedirs(processed_dir, exist_ok=True)
    os.makedirs(pairs_dir, exist_ok=True)
    os.makedirs(os.path.join(pairs_dir, "train"), exist_ok=True)
    os.makedirs(os.path.join(pairs_dir, "val"), exist_ok=True)

    # Find all image files recursively in the category directory
    image_files = []
    for root, _, files in os.walk(category_dir):
        for file in files:
            if file.lower().endswith(('.png', '.jpg', '.jpeg')):
                image_files.append(os.path.join(root, file))

    print(f"Found {len(image_files)} images for {category}")

    # If more than MAX_IMAGES, randomly select MAX_IMAGES
    if len(image_files) > MAX_IMAGES:
        print(f"Limiting to {MAX_IMAGES} random images")
        random.shuffle(image_files)
        image_files = image_files[:MAX_IMAGES]

    # Process each image
    processed_count = 0
    for img_path in tqdm(image_files, desc=f"Processing {category} images"):
        try:
            # Read the image
            img = cv2.imread(img_path)
            if img is None:
                continue

            # Resize to 256x256
            img = cv2.resize(img, (256, 256))

            # Create sketch using Canny edge detection
            gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            edges = cv2.Canny(gray, 100, 200)

            # Convert edges to 3-channel image
            sketch = cv2.cvtColor(edges, cv2.COLOR_GRAY2BGR)

            # Normalize images to range [-1, 1]
            img_norm = (img / 127.5) - 1
            sketch_norm = (sketch / 127.5) - 1

            # Save the original and sketch as a pair
            filename = os.path.basename(img_path).split('.')[0] + '.jpg'

            # Decide if this goes to train or validation (80/20 split)
            if np.random.random() < 0.8:
                split = "train"
            else:
                split = "val"

            pair_dir = os.path.join(pairs_dir, split)

            # Save the pair side by side
            pair = np.concatenate([sketch, img], axis=1)
            cv2.imwrite(os.path.join(pair_dir, filename), pair)

            # Also save individual files for flexibility
            cv2.imwrite(os.path.join(pair_dir, f"{filename.split('.')[0]}_sketch.jpg"), sketch)
            cv2.imwrite(os.path.join(pair_dir, f"{filename.split('.')[0]}_real.jpg"), img)

            processed_count += 1

        except Exception as e:
            print(f"Error processing {img_path}: {e}")

    print(f"Successfully processed {processed_count} images for {category}")
    return processed_count

In [5]:
def verify_image_counts():
    """Verify the total number of processed images and report counts"""
    print("\n--- Dataset Statistics ---")

    total_images = 0
    for category in os.listdir(base_dir):
        if category.endswith("_pairs"):
            pairs_dir = os.path.join(base_dir, category)
            train_dir = os.path.join(pairs_dir, "train")
            val_dir = os.path.join(pairs_dir, "val")

            if not os.path.exists(train_dir) or not os.path.exists(val_dir):
                print(f"Warning: Directories for {category} not found")
                continue

            train_count = len([f for f in os.listdir(train_dir) if f.endswith('.jpg') and not (f.endswith('_sketch.jpg') or f.endswith('_real.jpg'))])
            val_count = len([f for f in os.listdir(val_dir) if f.endswith('.jpg') and not (f.endswith('_sketch.jpg') or f.endswith('_real.jpg'))])

            print(f"{category[:-6]} dataset: {train_count} training images, {val_count} validation images, {train_count + val_count} total")
            total_images += train_count + val_count

    print(f"\nTotal images across all datasets: {total_images}")

In [6]:
# Main execution
if __name__ == "__main__":
    # List of categories to process
    all_categories = []

    # Manually specify categories if they exist locally
    for category in os.listdir(base_dir):
        if not category.endswith("_pairs") and not category.endswith("_processed"):
            all_categories.append(category)

    # Process each dataset
    for category in all_categories:
        preprocess_dataset(category)

    # Verify the counts
    verify_image_counts()

    print("\nAll datasets processed and organized!")
    print(f"Data is stored in: {base_dir}")

Preprocessing birds.pth dataset...
Found 0 images for birds.pth


Processing birds.pth images: 0it [00:00, ?it/s]

Successfully processed 0 images for birds.pth
Preprocessing cat dataset...





Found 29843 images for cat
Limiting to 10000 random images


Processing cat images: 100%|██████████| 10000/10000 [09:10<00:00, 18.16it/s]


Successfully processed 10000 images for cat
Preprocessing CUB_200_2011 dataset...
Found 11788 images for CUB_200_2011
Limiting to 10000 random images


Processing CUB_200_2011 images: 100%|██████████| 10000/10000 [05:35<00:00, 29.84it/s] 


Successfully processed 10000 images for CUB_200_2011
Preprocessing face dataset...
Found 202599 images for face
Limiting to 10000 random images


Processing face images: 100%|██████████| 10000/10000 [12:49<00:00, 12.99it/s]


Successfully processed 10000 images for face
Preprocessing shoes dataset...
Found 100091 images for shoes
Limiting to 10000 random images


Processing shoes images: 100%|██████████| 10000/10000 [07:20<00:00, 22.72it/s]


Successfully processed 10000 images for shoes
Preprocessing cats.pth dataset...
Found 0 images for cats.pth


Processing cats.pth images: 0it [00:00, ?it/s]


Successfully processed 0 images for cats.pth
Preprocessing faces.pth dataset...
Found 0 images for faces.pth


Processing faces.pth images: 0it [00:00, ?it/s]


Successfully processed 0 images for faces.pth
Preprocessing shoes_pairs.zip dataset...
Found 0 images for shoes_pairs.zip


Processing shoes_pairs.zip images: 0it [00:00, ?it/s]


Successfully processed 0 images for shoes_pairs.zip
Preprocessing flower.pth dataset...
Found 0 images for flower.pth


Processing flower.pth images: 0it [00:00, ?it/s]


Successfully processed 0 images for flower.pth

--- Dataset Statistics ---
birds.pth dataset: 0 training images, 0 validation images, 0 total
cat dataset: 8024 training images, 1976 validation images, 10000 total
CUB_200_2011 dataset: 7967 training images, 2033 validation images, 10000 total
face dataset: 8019 training images, 1981 validation images, 10000 total
shoes dataset: 6658 training images, 1920 validation images, 8578 total
cats.pth dataset: 0 training images, 0 validation images, 0 total
faces.pth dataset: 0 training images, 0 validation images, 0 total
shoes_pairs.zip dataset: 0 training images, 0 validation images, 0 total
flower.pth dataset: 0 training images, 0 validation images, 0 total

Total images across all datasets: 38578

All datasets processed and organized!
Data is stored in: /media/hghosh/HGHOSH DISK/dataset


https://mitmedialab.github.io/GAN-play/

In [7]:
import os
import random

def delete_excess_images(category, max_images):
    pairs_dir = os.path.join(base_dir, f"{category}_pairs")
    train_dir = os.path.join(pairs_dir, "train")
    val_dir = os.path.join(pairs_dir, "val")

    # Get all image files in train and val directories
    train_images = [f for f in os.listdir(train_dir) if f.endswith('.jpg') and not (f.endswith('_sketch.jpg') or f.endswith('_real.jpg'))]
    val_images = [f for f in os.listdir(val_dir) if f.endswith('.jpg') and not (f.endswith('_sketch.jpg') or f.endswith('_real.jpg'))]

    total_images = train_images + val_images

    # If total images exceed max_images, delete the excess one by one
    if len(total_images) > max_images:
        print(f"Deleting excess images for {category}...")

        # Randomly shuffle and select images to delete
        random.shuffle(total_images)
        images_to_delete = total_images[max_images:]

        for img in images_to_delete:
            if img in train_images:
                os.remove(os.path.join(train_dir, img))
                os.remove(os.path.join(train_dir, f"{img.split('.')[0]}_sketch.jpg"))
                os.remove(os.path.join(train_dir, f"{img.split('.')[0]}_real.jpg"))
            else:
                os.remove(os.path.join(val_dir, img))
                os.remove(os.path.join(val_dir, f"{img.split('.')[0]}_sketch.jpg"))
                os.remove(os.path.join(val_dir, f"{img.split('.')[0]}_real.jpg"))

            print(f"Deleted {img}")

        print(f"Deleted {len(images_to_delete)} images for {category}")

# Call delete_excess_images after preprocessing each category
for category in all_categories:
    preprocess_dataset(category)
    delete_excess_images(category, MAX_IMAGES)

# Verify the counts
verify_image_counts()

Preprocessing birds.pth dataset...
birds.pth dataset already processed. Skipping...
Preprocessing cat dataset...
cat dataset already processed. Skipping...
Preprocessing CUB_200_2011 dataset...
CUB_200_2011 dataset already processed. Skipping...
Preprocessing face dataset...
face dataset already processed. Skipping...
Preprocessing shoes dataset...
shoes dataset already processed. Skipping...
Preprocessing cats.pth dataset...
cats.pth dataset already processed. Skipping...
Preprocessing faces.pth dataset...
faces.pth dataset already processed. Skipping...
Preprocessing shoes_pairs.zip dataset...
shoes_pairs.zip dataset already processed. Skipping...
Preprocessing flower.pth dataset...
flower.pth dataset already processed. Skipping...

--- Dataset Statistics ---
birds.pth dataset: 0 training images, 0 validation images, 0 total
cat dataset: 8024 training images, 1976 validation images, 10000 total
CUB_200_2011 dataset: 7967 training images, 2033 validation images, 10000 total
face datas