In [None]:
import os
import random
import torch
from PIL import Image
from torchvision import transforms

# Root folder holding one sub-directory of images per class, plus the cache
# files the preprocessed train/val/test splits are written to.
# NOTE(review): "argumentation" looks like a typo for "augmentation", and the
# val/test paths carry a doubled ".pt.pt" extension. They are kept verbatim
# because other cells or already-cached files may depend on these exact
# names; if not, rename all three consistently.
data_dir = "./animals10"
train_dataset_path = "./animals10_train_argumentation.pt"
val_dataset_path = "./animals10_val_argumentation.pt.pt"
test_dataset_path = "./animals10_test_argumentation.pt.pt"

# Sub-directory names under data_dir. A class's integer label is its index
# in this list, so the order must not change once data has been cached.
class_names = [
    "cane",
    "cavallo",
    "elefante",
    "farfalla",
    "gallina",
    "gatto",
    "mucca",
    "pecora",
    "ragno",
    "scoiattolo",
]

# Target (height, width) every image is resized to.
image_size = (224, 224)

# Per-channel normalisation shared by both pipelines: after ToTensor scales
# pixels to [0, 1], (x - 0.5) / 0.5 maps them into [-1, 1].
_channel_mean = [0.5] * 3
_channel_std = [0.5] * 3

# Training pipeline: resize, then random flip / colour jitter / grayscale
# augmentation, then tensor conversion and normalisation.
train_transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
        transforms.RandomGrayscale(p=0.1),
        transforms.ToTensor(),
        transforms.Normalize(mean=_channel_mean, std=_channel_std),
    ]
)

# Validation/test pipeline: deterministic resize + normalisation, no
# augmentation.
# NOTE(review): Resize with an (h, w) tuple already produces exactly
# image_size, so the following CenterCrop is a no-op; kept for parity.
val_test_transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=_channel_mean, std=_channel_std),
    ]
)

# Fractions of the shuffled image list assigned to train / val / test.
# NOTE: the split code in this cell gives the test split whatever remains
# after the train and val slices, so test_ratio is not read there; it
# documents the intended proportion.
train_ratio, val_ratio, test_ratio = 0.7, 0.15, 0.15

# Process and save datasets

# Seed for the train/val/test shuffle so the split is reproducible across
# runs and machines.
split_seed = 42


def collect_image_label_pairs(root, classes, extensions=(".png", ".jpg", ".jpeg")):
    """Return ``(image_path, label)`` pairs for every image under ``root/<class>``.

    Args:
        root: directory containing one sub-directory per class.
        classes: class sub-directory names; a class's label is its index here.
        extensions: lower-case file suffixes (including the dot) to accept.

    Returns:
        A list of ``(path, label)`` tuples. Missing class directories are
        skipped. Files are visited in sorted order so the result does not
        depend on the filesystem's directory-listing order.
    """
    pairs = []
    # enumerate gives the label directly instead of a list.index() per image.
    for label, class_name in enumerate(classes):
        class_path = os.path.join(root, class_name)
        if not os.path.isdir(class_path):
            continue
        for img_name in sorted(os.listdir(class_path)):
            if img_name.lower().endswith(extensions):
                pairs.append((os.path.join(class_path, img_name), label))
    return pairs


def process_data(data, transform):
    """Load each image in ``data``, apply ``transform`` and stack the results.

    Args:
        data: sequence of ``(image_path, label)`` pairs.
        transform: callable mapping a PIL RGB image to a tensor.

    Returns:
        ``(images, labels)``. ``images`` stacks the transformed tensors along
        a new first dimension, and ``labels`` is a 1-D ``torch.long`` tensor.
        An empty ``data`` yields zero-length tensors instead of the error
        ``torch.stack`` raises on an empty list.
    """
    images, labels = [], []
    for img_path, label in data:
        # The context manager closes the file; convert() returns a new, fully
        # loaded image, so it stays usable after the file is closed.
        with Image.open(img_path) as img:
            images.append(transform(img.convert("RGB")))
        labels.append(label)
    if not images:
        # Shape assumes the 3-channel, image_size-sized output of the
        # transforms defined in this notebook.
        return (
            torch.empty((0, 3, *image_size)),
            torch.empty(0, dtype=torch.long),
        )
    return torch.stack(images), torch.tensor(labels, dtype=torch.long)


# Rebuild all three cached splits if any of them is missing.
if not (
    os.path.exists(train_dataset_path)
    and os.path.exists(val_dataset_path)
    and os.path.exists(test_dataset_path)
):
    print("Processing dataset with splits and saving...")

    image_label_pairs = collect_image_label_pairs(data_dir, class_names)
    if not image_label_pairs:
        raise FileNotFoundError(
            f"No images found under {data_dir!r} for classes {class_names}"
        )

    # A dedicated seeded RNG makes the split reproducible and leaves the
    # global `random` state untouched.
    random.Random(split_seed).shuffle(image_label_pairs)
    total_size = len(image_label_pairs)
    train_end = int(total_size * train_ratio)
    val_end = train_end + int(total_size * val_ratio)

    # The test split takes the remainder, so no sample is lost to rounding.
    train_data = image_label_pairs[:train_end]
    val_data = image_label_pairs[train_end:val_end]
    test_data = image_label_pairs[val_end:]

    # NOTE(review): train_transform's random augmentation is applied exactly
    # once here and then cached, so every epoch sees the same augmented view
    # of each image. For per-epoch augmentation, cache resized images and
    # apply the random transforms inside a Dataset at load time.
    # Memory: every sample becomes a float32 3x224x224 tensor (~0.6 MB), and
    # each split is held entirely in RAM before being saved.
    train_images, train_labels = process_data(train_data, train_transform)
    val_images, val_labels = process_data(val_data, val_test_transform)
    test_images, test_labels = process_data(test_data, val_test_transform)

    torch.save({"images": train_images, "labels": train_labels}, train_dataset_path)
    torch.save({"images": val_images, "labels": val_labels}, val_dataset_path)
    torch.save({"images": test_images, "labels": test_labels}, test_dataset_path)

    print(f"Saved {len(train_images)} training samples to {train_dataset_path}")
    print(f"Saved {len(val_images)} validation samples to {val_dataset_path}")
    print(f"Saved {len(test_images)} test samples to {test_dataset_path}")


Processing dataset with splits and saving...
