# Animals-10 Dataset
Download the raw data from: https://www.kaggle.com/datasets/alessiocorrado99/animals10?resource=download \
Then run the cell below to convert the images into a single `.pt` tensor file for faster loading.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, TensorDataset
from torchvision.models import vit_b_16
import os
from sklearn.model_selection import train_test_split
from PIL import Image
import numpy as np

In [2]:
# Input folder with the raw images and output file for the cached tensors
data_dir = "./animals10"
torch_dataset_path = "./animals10Dataset.pt"

# Class names double as the sub-folder names under data_dir;
# a class's position in this list is its integer label.
class_names = [
    "cane",
    "cavallo",
    "elefante",
    "farfalla",
    "gallina",
    "gatto",
    "mucca",
    "pecora",
    "ragno",
    "scoiattolo",
]
class_to_idx = dict(zip(class_names, range(len(class_names))))

# Preprocessing pipeline applied to every image:
# resize to 224x224, convert to a float tensor, then shift/scale each
# channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Process and store images in Torch tensor format.
# The conversion only runs when the cached file is absent, so re-running the
# cell is cheap; delete the file at torch_dataset_path to force a rebuild.
if not os.path.exists(torch_dataset_path):
    print("Processing dataset and saving to Torch tensor format...")
    # Leading dots so that only real file extensions match (not e.g. "foopng").
    image_extensions = (".png", ".jpg", ".jpeg")
    images, labels = [], []

    for class_name in class_names:
        class_path = os.path.join(data_dir, class_name)
        if not os.path.isdir(class_path):
            # Missing class folders are still skipped, but no longer silently.
            print(f"Warning: class folder not found, skipping: {class_path}")
            continue

        # os.listdir returns entries in arbitrary order; sorting makes the
        # order of samples in the saved tensors reproducible across runs.
        for img_name in sorted(os.listdir(class_path)):
            if not img_name.lower().endswith(image_extensions):
                continue
            img_path = os.path.join(class_path, img_name)
            # The context manager closes the underlying file handle right away
            # instead of leaving one open per image until garbage collection.
            with Image.open(img_path) as img:
                image = transform(img.convert("RGB"))
            images.append(image)
            labels.append(class_to_idx[class_name])

    if not images:
        # Fail with an actionable message instead of torch.stack's
        # error about an empty tensor list.
        raise FileNotFoundError(
            f"No images found under {data_dir!r}; expected one sub-folder per class: {class_names}"
        )

    # NOTE(review): all images are held in memory as float tensors and
    # torch.stack makes a second copy, so peak RAM is roughly twice the
    # dataset size — confirm this fits on the target machine.
    images = torch.stack(images)
    labels = torch.tensor(labels, dtype=torch.long)

    # Save to Torch tensor format
    torch.save({"images": images, "labels": labels}, torch_dataset_path)
    print("Dataset saved in Torch tensor format.")


Processing dataset and saving to Torch tensor format...
Dataset saved in Torch tensor format.
