# **📌 Custom Dataset for Horse Breed Classification**

## **1️⃣ Understanding the Dataset Structure**
Unlike standard datasets where images are placed inside subfolders (one per class), this dataset is structured differently:
- **All images are in a single folder** (no subfolders).
- **Class labels are embedded in the filenames** (e.g., `01_005.png`).

---

## **2️⃣ Why Can't We Use `ImageFolder`?**
`torchvision.datasets.ImageFolder` infers each image's label from the directory layout:

✅ **It expects each class to have its own subfolder**  
❌ **This dataset keeps every image in a single folder**  

Because the labels live in the filenames rather than in subfolder names, we **cannot** use `ImageFolder` directly. Instead, we will **extract class labels manually from filenames**.

---

## **3️⃣ What We Need to Do**
1️⃣ **Extract Image Paths** using `glob`.  
2️⃣ **Parse Labels from Filenames** (e.g., `01_001.png` → `01`).  
3️⃣ **Create a Custom Dataset Class** that loads images and their corresponding labels.
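
Steps 1 and 2 can be tried without downloading anything. The sketch below is only an illustration under stated assumptions: it creates a few empty placeholder files in a temporary folder (the filenames are made up to follow the `NN_MMM.png` pattern), collects them with a glob pattern, and parses the label prefix. The names `demo_paths`, `demo_labels`, and `label_counts` are hypothetical and not part of the lab code.

```python
import tempfile
from collections import Counter
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp_dir:
    folder = Path(tmp_dir)

    # Placeholder files that mimic the dataset's naming scheme (assumed: <class>_<index>.png)
    for name in ["01_001.png", "01_002.png", "02_001.png", "07_010.png"]:
        (folder / name).touch()

    # Step 1: collect all image paths; sorting makes the order reproducible
    demo_paths = sorted(folder.glob("*.png"))

    # Step 2: the text before the first underscore is the class label
    demo_labels = [int(path.name.split("_")[0]) for path in demo_paths]

# Counting labels is a quick way to check how many images each class has
label_counts = Counter(demo_labels)

print(demo_labels)   # [1, 1, 2, 7]
print(label_counts)
```

Counting the parsed labels this way is also a cheap check that no filename was parsed into an unexpected class.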


In [None]:
# # Uncomment this cell when running the lab on Google Colab.
# import os

# # Set KaggleHub cache to a directory inside /content/
# os.environ["KAGGLEHUB_CACHE"] = "/content/data"

In [None]:
import kagglehub

# Download latest version
dataset_path = kagglehub.dataset_download("olgabelitskaya/horse-breeds")

print("Path to dataset files:", dataset_path)

Path to dataset files: /kaggle/input/horse-breeds


In [None]:
import os
import glob

# Get all image paths
image_paths = glob.glob(os.path.join(dataset_path, "*.png"))

# Extract labels from filenames.
# os.path.basename() returns the filename part of a path (you could also split the path manually):
# `../horse-breeds/01_001.png` → basename() → `01_001.png` → split("_")[0] → `01` → int() → 1
image_labels = [int(os.path.basename(path).split("_")[0]) for path in image_paths]

# Print example
print(f"Example Image Path: {image_paths[0]}")
print(f"Extracted Label: {image_labels[0]}")

Example Image Path: /kaggle/input/horse-breeds/01_103.png
Extracted Label: 1
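
Note that the labels parsed from the filenames start at 1. If you later train a classifier with a loss such as `torch.nn.CrossEntropyLoss`, the targets need to be class indices in the range `0` to `C - 1`. A minimal sketch of one way to remap them is shown below; the label values in `example_labels` are made up for illustration, and `label_to_index` and `example_targets` are hypothetical names.

```python
# Hypothetical labels, as they would come out of the filename parsing step
example_labels = [1, 3, 2, 7, 1, 2]

# Sorted unique class codes found in the data
class_ids = sorted(set(example_labels))

# Map each original class code to a contiguous 0-based index
label_to_index = {class_id: index for index, class_id in enumerate(class_ids)}

# Remapped targets, ready to be used as class indices
example_targets = [label_to_index[label] for label in example_labels]

print(label_to_index)   # {1: 0, 2: 1, 3: 2, 7: 3}
print(example_targets)  # [0, 2, 1, 3, 0, 1]
```

Building the mapping from the labels actually present also handles class codes that are not contiguous, at the cost of having to keep `label_to_index` around to translate predictions back.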


##### Then pass the paths and labels to a dataset class, similar to what we did before.

In [None]:
from torch.utils.data import Dataset
from PIL import Image

class HorsesDataset(Dataset):
    # Copied from the SmokingDataset class
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths  # List of image paths
        self.labels = labels  # Corresponding labels
        self.transform = transform  # Transformations to apply

    def __len__(self):
        return len(self.image_paths)  # Total number of images

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]  # Get image path
        label = self.labels[idx]  # Get corresponding label

        # Load image
        image = Image.open(image_path)

        # Apply transformations (if any)
        if self.transform:
            image = self.transform(image)

        return image, label  # Return processed image and its label


dataset = HorsesDataset(image_paths, image_labels)  # Pass the paths and labels here
dataset[0]

(<PIL.PngImagePlugin.PngImageFile image mode=RGBA size=256x256>, 1)