In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import pandas as pd
import numpy as np
import os
from PIL import Image

In [2]:
# Root folder of the dataset; each class name below is joined onto it to locate
# that class's images. The original location stays the default, and the
# FLOWERS_DATASET_ROOT environment variable overrides it so the notebook can
# run on machines where the data lives elsewhere.
dataset_root = os.environ.get("FLOWERS_DATASET_ROOT", "E:/DataSets/flowers")

# Class names. A class's position in this list is used as its integer label.
labels = [
    "astilbe",
    "bellflower",
    "black_eyed_susan",
    "calendula",
    "california_poppy",
    "carnation",
    "common_daisy",
    "coreopsis",
    "daffodil",
    "dandelion",
    "iris",
    "magnolia",
    "rose",
    "sunflower",
    "tulip",
    "water_lily",
]

In [3]:
# Build one record per image file: its full path and the integer index of its
# class (the position of the class name in `labels`).
records = []
for label_index, label in enumerate(labels):
    img_dir = os.path.join(dataset_root, label)
    # os.listdir returns entries in arbitrary order; sorting keeps the row
    # order of the table reproducible across runs and machines.
    for image_name in sorted(os.listdir(img_dir)):
        img_path = os.path.join(img_dir, image_name)
        # Skip anything that is not a regular file (e.g. nested folders),
        # since every row is later opened as an image.
        if not os.path.isfile(img_path):
            continue
        records.append({"image_path": img_path, "label": label_index})

# Explicit columns keep the schema intact even if no files were found.
data = pd.DataFrame(records, columns=["image_path", "label"])
data.head()

Unnamed: 0,image_path,label
0,E:/DataSets/flowers\astilbe\10091895024_a2ea04...,0
1,E:/DataSets/flowers\astilbe\1033455028_f0c6518...,0
2,E:/DataSets/flowers\astilbe\10373087134_927b53...,0
3,E:/DataSets/flowers\astilbe\1052212431_4963309...,0
4,E:/DataSets/flowers\astilbe\1052219251_d03970e...,0


In [4]:
class FlowerDataset(Dataset):
    """Dataset that yields ``(image, label)`` pairs from a table of image paths.

    Args:
        data: Table supporting ``len()`` and positional ``.iloc`` access (e.g. a
            pandas DataFrame) with an ``"image_path"`` and a ``"label"`` column.
        transform: Optional callable applied to the loaded PIL image. When it
            is falsy (e.g. ``None``) the PIL image is returned unchanged.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        """Return the number of rows (samples) in the table."""
        return len(self.data)

    def __getitem__(self, index):
        """Load the sample stored at positional ``index``.

        Returns:
            tuple: ``(image, label)`` where ``image`` is the transformed image
            (or the PIL image when no transform is set) and ``label`` is the
            value of the row's ``"label"`` column.
        """
        # Look the row up once instead of once per column.
        row = self.data.iloc[index]
        # Image.open is lazy: read the pixel data inside the context manager so
        # the underlying file is closed deterministically instead of staying
        # open until garbage collection. The image remains usable afterwards.
        with Image.open(row["image_path"]) as image:
            image.load()
        label = row["label"]
        if self.transform:
            image = self.transform(image)
        return image, label

In [5]:
# Preprocessing / augmentation pipeline applied to every PIL image.
# It is bound to its own name on purpose: assigning the pipeline to
# `transforms` would overwrite the `torchvision.transforms` module imported at
# the top of the notebook, and re-running this cell would then fail on
# `transforms.Compose`.
image_transform = transforms.Compose([
    # Convert to 3-channel RGB first so that every following step (resizing,
    # rotation, normalisation with 3-channel statistics) sees the same mode.
    transforms.Lambda(lambda x: x.convert("RGB") if x.mode != "RGB" else x),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    # Random augmentations: small rotation and horizontal mirroring.
    transforms.RandomRotation(15),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Per-channel mean/std; the values match the commonly used ImageNet statistics.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
dataset = FlowerDataset(data, transform=image_transform)
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)

In [6]:
# Sanity check: iterate over the loader once and print the shapes of each batch.
# The batch variables get their own names so the `labels` list of class names
# defined earlier is not overwritten by a tensor (which would break a re-run of
# the cell that builds the image table).
for batch_images, batch_labels in dataloader:
    print(batch_images.shape, batch_labels.shape)

torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 224, 224]) torch.Size([128])
torch.Size([128, 3, 