# In [None]:
import os
import numpy as np
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

def preprocess_and_save_images(root, save_path, transform=None):
    """Load a class-per-folder image dataset and save it as two ``.npy`` files.

    Each immediate entry of ``root`` that contains an ``images`` subdirectory
    is treated as one class. Classes are labelled 0, 1, 2, ... in sorted order
    of their directory names. Only those class directories are counted, so
    labels stay contiguous even when ``root`` also holds loose files or
    unrelated folders.

    Args:
        root: Dataset root; the expected layout is
            ``root/<class_name>/images/<image file>``.
        save_path: Output directory, created if missing. ``images.npy`` and
            ``labels.npy`` are written there.
        transform: Optional callable applied to each RGB ``PIL.Image`` before
            it is converted with ``np.array``. All resulting arrays should share
            one shape so they can form a single regular (N, ...) array.

    Images that fail to load, transform, or convert are reported with
    ``print`` and skipped; processing continues with the next file.
    """
    # Allowed image extensions (compared case-insensitively below).
    valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

    # Only entries that actually have an ``images`` folder count as classes,
    # so stray files (e.g. metadata) do not consume a label index.
    class_names = [
        name
        for name in sorted(os.listdir(root))
        if os.path.isdir(os.path.join(root, name, "images"))
    ]

    images = []
    labels = []

    # Wrapping the list (not an enumerate object) lets tqdm know the total.
    for class_id, class_name in enumerate(tqdm(class_names, desc="Processing classes")):
        class_dir = os.path.join(root, class_name, "images")
        # Sort so the sample order is deterministic across runs/filesystems.
        for image_name in sorted(os.listdir(class_dir)):
            if not image_name.lower().endswith(valid_extensions):
                continue
            image_path = os.path.join(class_dir, image_name)
            try:
                with Image.open(image_path) as img:
                    # convert() returns a converted copy, so the result stays
                    # usable after the file handle is closed.
                    img = img.convert('RGB')

                if transform is not None:
                    img = transform(img)

                img_array = np.array(img)
            except Exception as e:
                # Best-effort: report and skip unreadable or corrupt files.
                print(f"Failed to load image {image_path}: {e}")
                continue
            images.append(img_array)
            labels.append(class_id)

    # Equal-shaped per-image arrays combine into one (N, ...) array.
    images_np = np.array(images)
    # Fixed integer dtype regardless of platform or an empty result.
    labels_np = np.array(labels, dtype=np.int64)

    os.makedirs(save_path, exist_ok=True)
    np.save(os.path.join(save_path, 'images.npy'), images_np)
    np.save(os.path.join(save_path, 'labels.npy'), labels_np)
    print(f"Saved {len(images)} images and labels to {save_path}")

# Per-image pipeline: resize to a fixed 256x256, then convert with ToTensor.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

# Dataset root to scan and the directory that receives images.npy / labels.npy.
# NOTE(review): preprocess_and_save_images expects <root>/<class>/images/.
# In the standard tiny-imagenet-200 layout those class folders sit under the
# "train" subdirectory — confirm this root points where intended.
dataset_root = '/graphics/scratch2/datasets/tiny-imagenet-200/tiny-imagenet-200'
npy_save_path = '.'

# Run the preprocessing and write the .npy outputs.
preprocess_and_save_images(
    root=dataset_root,
    save_path=npy_save_path,
    transform=transform,
)
